Squashed 'subtree/rabit/' changes from b15f6cd..e08542c

e08542c fix doc
e95c962 remove I prefix from interface, serializable now takes in pointer

git-subtree-dir: subtree/rabit
git-subtree-split: e08542c6357bc044199e4876d9bea949d05f614c
This commit is contained in:
tqchen 2015-04-08 17:39:45 -07:00
parent 89244b4aec
commit 3d11f56880
29 changed files with 225 additions and 228 deletions

View File

@ -95,7 +95,7 @@ WARN_LOGFILE =
#--------------------------------------------------------------------------- #---------------------------------------------------------------------------
# configuration options related to the input files # configuration options related to the input files
#--------------------------------------------------------------------------- #---------------------------------------------------------------------------
INPUT = INPUT = . dmlc
INPUT_ENCODING = UTF-8 INPUT_ENCODING = UTF-8
FILE_PATTERNS = FILE_PATTERNS =
RECURSIVE = NO RECURSIVE = NO

View File

@ -14,7 +14,7 @@ namespace dmlc {
/*! /*!
* \brief interface of stream I/O for serialization * \brief interface of stream I/O for serialization
*/ */
class IStream { class Stream {
public: public:
/*! /*!
* \brief reads data from a stream * \brief reads data from a stream
@ -30,7 +30,7 @@ class IStream {
*/ */
virtual void Write(const void *ptr, size_t size) = 0; virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~IStream(void) {} virtual ~Stream(void) {}
/*! /*!
* \brief generic factory function * \brief generic factory function
* create an stream, the stream will close the underlying files * create an stream, the stream will close the underlying files
@ -38,8 +38,9 @@ class IStream {
* \param uri the uri of the input currently we support * \param uri the uri of the input currently we support
* hdfs://, s3://, and file:// by default file:// will be used * hdfs://, s3://, and file:// by default file:// will be used
* \param flag can be "w", "r", "a" * \param flag can be "w", "r", "a"
* \return a created stream
*/ */
static IStream *Create(const char *uri, const char* const flag); static Stream *Create(const char *uri, const char* const flag);
// helper functions to write/read different data structures // helper functions to write/read different data structures
/*! /*!
* \brief writes a vector * \brief writes a vector
@ -68,10 +69,10 @@ class IStream {
}; };
/*! \brief interface of i/o stream that support seek */ /*! \brief interface of i/o stream that support seek */
class ISeekStream: public IStream { class SeekStream: public Stream {
public: public:
// virtual destructor // virtual destructor
virtual ~ISeekStream(void) {} virtual ~SeekStream(void) {}
/*! \brief seek to certain position of the file */ /*! \brief seek to certain position of the file */
virtual void Seek(size_t pos) = 0; virtual void Seek(size_t pos) = 0;
/*! \brief tell the position of the stream */ /*! \brief tell the position of the stream */
@ -81,18 +82,18 @@ class ISeekStream: public IStream {
}; };
/*! \brief interface for serializable objects */ /*! \brief interface for serializable objects */
class ISerializable { class Serializable {
public: public:
/*! /*!
* \brief load the model from a stream * \brief load the model from a stream
* \param fi stream where to load the model from * \param fi stream where to load the model from
*/ */
virtual void Load(IStream &fi) = 0; virtual void Load(Stream *fi) = 0;
/*! /*!
* \brief saves the model to a stream * \brief saves the model to a stream
* \param fo stream where to save the model to * \param fo stream where to save the model to
*/ */
virtual void Save(IStream &fo) const = 0; virtual void Save(Stream *fo) const = 0;
}; };
/*! /*!
@ -115,6 +116,7 @@ class InputSplit {
* \param uri the uri of the input, can contain hdfs prefix * \param uri the uri of the input, can contain hdfs prefix
* \param part_index the part id of current input * \param part_index the part id of current input
* \param num_parts total number of splits * \param num_parts total number of splits
* \return a created input split
*/ */
static InputSplit* Create(const char *uri, static InputSplit* Create(const char *uri,
unsigned part_index, unsigned part_index,
@ -123,7 +125,7 @@ class InputSplit {
// implementations of inline functions // implementations of inline functions
template<typename T> template<typename T>
inline void IStream::Write(const std::vector<T> &vec) { inline void Stream::Write(const std::vector<T> &vec) {
size_t sz = vec.size(); size_t sz = vec.size();
this->Write(&sz, sizeof(sz)); this->Write(&sz, sizeof(sz));
if (sz != 0) { if (sz != 0) {
@ -131,7 +133,7 @@ inline void IStream::Write(const std::vector<T> &vec) {
} }
} }
template<typename T> template<typename T>
inline bool IStream::Read(std::vector<T> *out_vec) { inline bool Stream::Read(std::vector<T> *out_vec) {
size_t sz; size_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false; if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_vec->resize(sz); out_vec->resize(sz);
@ -140,14 +142,14 @@ inline bool IStream::Read(std::vector<T> *out_vec) {
} }
return true; return true;
} }
inline void IStream::Write(const std::string &str) { inline void Stream::Write(const std::string &str) {
size_t sz = str.length(); size_t sz = str.length();
this->Write(&sz, sizeof(sz)); this->Write(&sz, sizeof(sz));
if (sz != 0) { if (sz != 0) {
this->Write(&str[0], sizeof(char) * sz); this->Write(&str[0], sizeof(char) * sz);
} }
} }
inline bool IStream::Read(std::string *out_str) { inline bool Stream::Read(std::string *out_str) {
size_t sz; size_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false; if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_str->resize(sz); out_str->resize(sz);

View File

@ -16,7 +16,7 @@
#if __cplusplus >= 201103L #if __cplusplus >= 201103L
#include <functional> #include <functional>
#endif // C++11 #endif // C++11
// contains definition of ISerializable // contains definition of Serializable
#include "./rabit_serializable.h" #include "./rabit_serializable.h"
// engine definition of rabit, defines internal implementation // engine definition of rabit, defines internal implementation
// to use rabit interface, there is no need to read engine.h // to use rabit interface, there is no need to read engine.h
@ -183,8 +183,8 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
inline int LoadCheckPoint(ISerializable *global_model, inline int LoadCheckPoint(Serializable *global_model,
ISerializable *local_model = NULL); Serializable *local_model = NULL);
/*! /*!
* \brief checkpoints the model, meaning a stage of execution has finished. * \brief checkpoints the model, meaning a stage of execution has finished.
* every time we call check point, a version number will be increased by one * every time we call check point, a version number will be increased by one
@ -199,8 +199,8 @@ inline int LoadCheckPoint(ISerializable *global_model,
* So, only CheckPoint with the global_model if possible * So, only CheckPoint with the global_model if possible
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
inline void CheckPoint(const ISerializable *global_model, inline void CheckPoint(const Serializable *global_model,
const ISerializable *local_model = NULL); const Serializable *local_model = NULL);
/*! /*!
* \brief This function can be used to replace CheckPoint for global_model only, * \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met (see detailed explanation). * when certain condition is met (see detailed explanation).
@ -222,7 +222,7 @@ inline void CheckPoint(const ISerializable *global_model,
* is the same in every node * is the same in every node
* \sa LoadCheckPoint, CheckPoint, VersionNumber * \sa LoadCheckPoint, CheckPoint, VersionNumber
*/ */
inline void LazyCheckPoint(const ISerializable *global_model); inline void LazyCheckPoint(const Serializable *global_model);
/*! /*!
* \return version number of the current stored model, * \return version number of the current stored model,
* which means how many calls to CheckPoint we made so far * which means how many calls to CheckPoint we made so far

View File

@ -94,8 +94,8 @@ class IEngine {
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
virtual int LoadCheckPoint(ISerializable *global_model, virtual int LoadCheckPoint(Serializable *global_model,
ISerializable *local_model = NULL) = 0; Serializable *local_model = NULL) = 0;
/*! /*!
* \brief checkpoints the model, meaning a stage of execution was finished * \brief checkpoints the model, meaning a stage of execution was finished
* every time we call check point, a version number increases by ones * every time we call check point, a version number increases by ones
@ -112,8 +112,8 @@ class IEngine {
* *
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const ISerializable *global_model, virtual void CheckPoint(const Serializable *global_model,
const ISerializable *local_model = NULL) = 0; const Serializable *local_model = NULL) = 0;
/*! /*!
* \brief This function can be used to replace CheckPoint for global_model only, * \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met (see detailed explanation). * when certain condition is met (see detailed explanation).
@ -134,7 +134,7 @@ class IEngine {
* is the same in every node * is the same in every node
* \sa LoadCheckPoint, CheckPoint, VersionNumber * \sa LoadCheckPoint, CheckPoint, VersionNumber
*/ */
virtual void LazyCheckPoint(const ISerializable *global_model) = 0; virtual void LazyCheckPoint(const Serializable *global_model) = 0;
/*! /*!
* \return version number of the current stored model, * \return version number of the current stored model,
* which means how many calls to CheckPoint we made so far * which means how many calls to CheckPoint we made so far

View File

@ -16,10 +16,10 @@
namespace rabit { namespace rabit {
namespace utils { namespace utils {
/*! \brief re-use definition of dmlc::ISeekStream */ /*! \brief re-use definition of dmlc::SeekStream */
typedef dmlc::ISeekStream ISeekStream; typedef dmlc::SeekStream SeekStream;
/*! \brief fixed size memory buffer */ /*! \brief fixed size memory buffer */
struct MemoryFixSizeBuffer : public ISeekStream { struct MemoryFixSizeBuffer : public SeekStream {
public: public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
: p_buffer_(reinterpret_cast<char*>(p_buffer)), : p_buffer_(reinterpret_cast<char*>(p_buffer)),
@ -61,7 +61,7 @@ struct MemoryFixSizeBuffer : public ISeekStream {
}; // class MemoryFixSizeBuffer }; // class MemoryFixSizeBuffer
/*! \brief a in memory buffer that can be read and write as stream interface */ /*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public ISeekStream { struct MemoryBufferStream : public SeekStream {
public: public:
explicit MemoryBufferStream(std::string *p_buffer) explicit MemoryBufferStream(std::string *p_buffer)
: p_buffer_(p_buffer) { : p_buffer_(p_buffer) {

View File

@ -178,17 +178,17 @@ inline void TrackerPrintf(const char *fmt, ...) {
} }
#endif #endif
// load latest check point // load latest check point
inline int LoadCheckPoint(ISerializable *global_model, inline int LoadCheckPoint(Serializable *global_model,
ISerializable *local_model) { Serializable *local_model) {
return engine::GetEngine()->LoadCheckPoint(global_model, local_model); return engine::GetEngine()->LoadCheckPoint(global_model, local_model);
} }
// checkpoint the model, meaning we finished a stage of execution // checkpoint the model, meaning we finished a stage of execution
inline void CheckPoint(const ISerializable *global_model, inline void CheckPoint(const Serializable *global_model,
const ISerializable *local_model) { const Serializable *local_model) {
engine::GetEngine()->CheckPoint(global_model, local_model); engine::GetEngine()->CheckPoint(global_model, local_model);
} }
// lazy checkpoint the model, only remember the pointer to global_model // lazy checkpoint the model, only remember the pointer to global_model
inline void LazyCheckPoint(const ISerializable *global_model) { inline void LazyCheckPoint(const Serializable *global_model) {
engine::GetEngine()->LazyCheckPoint(global_model); engine::GetEngine()->LazyCheckPoint(global_model);
} }
// return the version number of currently stored model // return the version number of currently stored model

View File

@ -14,14 +14,14 @@
namespace rabit { namespace rabit {
/*! /*!
* \brief defines stream used in rabit * \brief defines stream used in rabit
* see definition of IStream in dmlc/io.h * see definition of Stream in dmlc/io.h
*/ */
typedef dmlc::IStream IStream; typedef dmlc::Stream Stream;
/*! /*!
* \brief defines serializable objects used in rabit * \brief defines serializable objects used in rabit
* see definition of ISerializable in dmlc/io.h * see definition of Serializable in dmlc/io.h
*/ */
typedef dmlc::ISerializable ISerializable; typedef dmlc::Serializable Serializable;
} // namespace rabit } // namespace rabit
#endif // RABIT_RABIT_SERIALIZABLE_H_ #endif // RABIT_RABIT_SERIALIZABLE_H_

View File

@ -33,9 +33,9 @@ static const char EncodeTable[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
} // namespace base64 } // namespace base64
/*! \brief the stream that reads from base64, note we take from file pointers */ /*! \brief the stream that reads from base64, note we take from file pointers */
class Base64InStream: public IStream { class Base64InStream: public Stream {
public: public:
explicit Base64InStream(IStream *fs) : reader_(256) { explicit Base64InStream(Stream *fs) : reader_(256) {
reader_.set_stream(fs); reader_.set_stream(fs);
num_prev = 0; tmp_ch = 0; num_prev = 0; tmp_ch = 0;
} }
@ -147,9 +147,9 @@ class Base64InStream: public IStream {
static const bool kStrictCheck = false; static const bool kStrictCheck = false;
}; };
/*! \brief the stream that write to base64, note we take from file pointers */ /*! \brief the stream that write to base64, note we take from file pointers */
class Base64OutStream: public IStream { class Base64OutStream: public Stream {
public: public:
explicit Base64OutStream(IStream *fp) : fp(fp) { explicit Base64OutStream(Stream *fp) : fp(fp) {
buf_top = 0; buf_top = 0;
} }
virtual void Write(const void *ptr, size_t size) { virtual void Write(const void *ptr, size_t size) {
@ -198,7 +198,7 @@ class Base64OutStream: public IStream {
} }
private: private:
IStream *fp; Stream *fp;
int buf_top; int buf_top;
unsigned char buf[4]; unsigned char buf[4];
std::string out_buf; std::string out_buf;

View File

@ -20,7 +20,7 @@ class StreamBufferReader {
/*! /*!
* \brief set input stream * \brief set input stream
*/ */
inline void set_stream(IStream *stream) { inline void set_stream(Stream *stream) {
stream_ = stream; stream_ = stream;
read_len_ = read_ptr_ = 1; read_len_ = read_ptr_ = 1;
} }
@ -45,7 +45,7 @@ class StreamBufferReader {
private: private:
/*! \brief the underlying stream */ /*! \brief the underlying stream */
IStream *stream_; Stream *stream_;
/*! \brief buffer to hold data */ /*! \brief buffer to hold data */
std::string buffer_; std::string buffer_;
/*! \brief length of valid data in buffer */ /*! \brief length of valid data in buffer */

View File

@ -15,7 +15,7 @@
namespace rabit { namespace rabit {
namespace io { namespace io {
/*! \brief implementation of file i/o stream */ /*! \brief implementation of file i/o stream */
class FileStream : public utils::ISeekStream { class FileStream : public utils::SeekStream {
public: public:
explicit FileStream(const char *fname, const char *mode) explicit FileStream(const char *fname, const char *mode)
: use_stdio(false) { : use_stdio(false) {
@ -84,7 +84,7 @@ class FileProvider : public LineSplitter::IFileProvider {
} }
// destrucor // destrucor
virtual ~FileProvider(void) {} virtual ~FileProvider(void) {}
virtual utils::ISeekStream *Open(size_t file_index) { virtual utils::SeekStream *Open(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceed bound"); utils::Assert(file_index < fnames_.size(), "file index exceed bound");
return new FileStream(fnames_[file_index].c_str(), "rb"); return new FileStream(fnames_[file_index].c_str(), "rb");
} }

View File

@ -16,7 +16,7 @@
/*! \brief io interface */ /*! \brief io interface */
namespace rabit { namespace rabit {
namespace io { namespace io {
class HDFSStream : public ISeekStream { class HDFSStream : public SeekStream {
public: public:
HDFSStream(hdfsFS fs, HDFSStream(hdfsFS fs,
const char *fname, const char *fname,
@ -147,7 +147,7 @@ class HDFSProvider : public LineSplitter::IFileProvider {
virtual const std::vector<size_t> &FileSize(void) const { virtual const std::vector<size_t> &FileSize(void) const {
return fsize_; return fsize_;
} }
virtual ISeekStream *Open(size_t file_index) { virtual SeekStream *Open(size_t file_index) {
utils::Assert(file_index < fnames_.size(), "file index exceed bound"); utils::Assert(file_index < fnames_.size(), "file index exceed bound");
return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false); return new HDFSStream(fs_, fnames_[file_index].c_str(), "r", false);
} }

View File

@ -50,7 +50,7 @@ inline InputSplit *CreateInputSplit(const char *uri,
} }
template<typename TStream> template<typename TStream>
class StreamAdapter : public IStream { class StreamAdapter : public Stream {
public: public:
explicit StreamAdapter(TStream *stream) explicit StreamAdapter(TStream *stream)
: stream_(stream) { : stream_(stream) {
@ -75,9 +75,9 @@ class StreamAdapter : public IStream {
* \param uri the uri of the input, can contain hdfs prefix * \param uri the uri of the input, can contain hdfs prefix
* \param mode can be 'w' or 'r' for read or write * \param mode can be 'w' or 'r' for read or write
*/ */
inline IStream *CreateStream(const char *uri, const char *mode) { inline Stream *CreateStream(const char *uri, const char *mode) {
#if RABIT_USE_WORMHOLE #if RABIT_USE_WORMHOLE
return new StreamAdapter<dmlc::IStream>(dmlc::IStream::Create(uri, mode)); return new StreamAdapter<dmlc::Stream>(dmlc::Stream::Create(uri, mode));
#else #else
using namespace std; using namespace std;
if (!strncmp(uri, "file://", 7)) { if (!strncmp(uri, "file://", 7)) {

View File

@ -26,12 +26,12 @@ namespace rabit {
* \brief namespace to handle input split and filesystem interfacing * \brief namespace to handle input split and filesystem interfacing
*/ */
namespace io { namespace io {
/*! \brief reused ISeekStream's definition */ /*! \brief reused SeekStream's definition */
#if RABIT_USE_WORMHOLE #if RABIT_USE_WORMHOLE
typedef dmlc::ISeekStream ISeekStream; typedef dmlc::SeekStream SeekStream;
typedef dmlc::InputSplit InputSplit; typedef dmlc::InputSplit InputSplit;
#else #else
typedef utils::ISeekStream ISeekStream; typedef utils::SeekStream SeekStream;
/*! /*!
* \brief user facing input split helper, * \brief user facing input split helper,
* can be used to get the partition of data used by current node * can be used to get the partition of data used by current node
@ -65,7 +65,7 @@ inline InputSplit *CreateInputSplit(const char *uri,
* \param uri the uri of the input, can contain hdfs prefix * \param uri the uri of the input, can contain hdfs prefix
* \param mode can be 'w' or 'r' for read or write * \param mode can be 'w' or 'r' for read or write
*/ */
inline IStream *CreateStream(const char *uri, const char *mode); inline Stream *CreateStream(const char *uri, const char *mode);
} // namespace io } // namespace io
} // namespace rabit } // namespace rabit

View File

@ -26,7 +26,7 @@ class LineSplitter : public InputSplit {
* \return the corresponding seek stream at head of the stream * \return the corresponding seek stream at head of the stream
* the seek stream's resource can be freed by calling delete * the seek stream's resource can be freed by calling delete
*/ */
virtual ISeekStream *Open(size_t file_index) = 0; virtual SeekStream *Open(size_t file_index) = 0;
/*! /*!
* \return const reference to size of each files * \return const reference to size of each files
*/ */
@ -142,7 +142,7 @@ class LineSplitter : public InputSplit {
/*! \brief FileProvider */ /*! \brief FileProvider */
IFileProvider *provider_; IFileProvider *provider_;
/*! \brief current input stream */ /*! \brief current input stream */
utils::ISeekStream *fs_; utils::SeekStream *fs_;
/*! \brief file pointer of which file to read on */ /*! \brief file pointer of which file to read on */
size_t file_ptr_; size_t file_ptr_;
/*! \brief file pointer where the end of file lies */ /*! \brief file pointer where the end of file lies */

View File

@ -7,22 +7,54 @@
using namespace rabit; using namespace rabit;
// simple dense matrix, mshadow or Eigen matrix was better
// this was was OK
struct Matrix {
inline void Init(size_t nrow, size_t ncol, float v = 0.0f) {
this->nrow = nrow;
this->ncol = ncol;
data.resize(nrow * ncol);
std::fill(data.begin(), data.end(), v);
}
inline float *operator[](size_t i) {
return &data[0] + i * ncol;
}
inline const float *operator[](size_t i) const {
return &data[0] + i * ncol;
}
inline void Print(utils::Stream *fo) {
for (size_t i = 0; i < data.size(); ++i) {
std::ostringstream ss;
ss << data[i];
if ((i+1) % ncol == 0) {
ss << '\n';
} else {
ss << ' ';
}
}
std::string s = ss.str();
}
// number of data
size_t nrow, ncol;
std::vector<float> data;
};
// kmeans model // kmeans model
class Model : public rabit::ISerializable { class Model : public rabit::Serializable {
public: public:
// matrix of centroids // matrix of centroids
Matrix centroids; Matrix centroids;
// load from stream // load from stream
virtual void Load(rabit::IStream &fi) { virtual void Load(rabit::Stream *fi) {
fi.Read(&centroids.nrow, sizeof(centroids.nrow)); fi->Read(&centroids.nrow, sizeof(centroids.nrow));
fi.Read(&centroids.ncol, sizeof(centroids.ncol)); fi->Read(&centroids.ncol, sizeof(centroids.ncol));
fi.Read(&centroids.data); fi->Read(&centroids.data);
} }
/*! \brief save the model to the stream */ /*! \brief save the model to the stream */
virtual void Save(rabit::IStream &fo) const { virtual void Save(rabit::Stream *fo) const {
fo.Write(&centroids.nrow, sizeof(centroids.nrow)); fo->Write(&centroids.nrow, sizeof(centroids.nrow));
fo.Write(&centroids.ncol, sizeof(centroids.ncol)); fo->Write(&centroids.ncol, sizeof(centroids.ncol));
fo.Write(centroids.data); fo->Write(centroids.data);
} }
virtual void InitModel(unsigned num_cluster, unsigned feat_dim) { virtual void InitModel(unsigned num_cluster, unsigned feat_dim) {
centroids.Init(num_cluster, feat_dim); centroids.Init(num_cluster, feat_dim);
@ -153,7 +185,7 @@ int main(int argc, char *argv[]) {
} }
} }
model.Normalize(); model.Normalize();
rabit::CheckPoint(&model); rabit::LazyCheckPoint(&model);
} }
// output the model file to somewhere // output the model file to somewhere
if (rabit::GetRank() == 0) { if (rabit::GetRank() == 0) {

View File

@ -75,7 +75,7 @@ class LinearObjFunction : public solver::IObjFunction<float> {
printf("Finishing writing to %s\n", name_pred.c_str()); printf("Finishing writing to %s\n", name_pred.c_str());
} }
inline void LoadModel(const char *fname) { inline void LoadModel(const char *fname) {
IStream *fi = io::CreateStream(fname, "r"); Stream *fi = io::CreateStream(fname, "r");
std::string header; header.resize(4); std::string header; header.resize(4);
// check header for different binary encode // check header for different binary encode
// can be base64 or binary // can be base64 or binary
@ -84,9 +84,9 @@ class LinearObjFunction : public solver::IObjFunction<float> {
if (header == "bs64") { if (header == "bs64") {
io::Base64InStream bsin(fi); io::Base64InStream bsin(fi);
bsin.InitPosition(); bsin.InitPosition();
model.Load(bsin); model.Load(&bsin);
} else if (header == "binf") { } else if (header == "binf") {
model.Load(*fi); model.Load(fi);
} else { } else {
utils::Error("invalid model file"); utils::Error("invalid model file");
} }
@ -95,15 +95,15 @@ class LinearObjFunction : public solver::IObjFunction<float> {
inline void SaveModel(const char *fname, inline void SaveModel(const char *fname,
const float *wptr, const float *wptr,
bool save_base64 = false) { bool save_base64 = false) {
IStream *fo = io::CreateStream(fname, "w"); Stream *fo = io::CreateStream(fname, "w");
if (save_base64 != 0 || !strcmp(fname, "stdout")) { if (save_base64 != 0 || !strcmp(fname, "stdout")) {
fo->Write("bs64\t", 5); fo->Write("bs64\t", 5);
io::Base64OutStream bout(fo); io::Base64OutStream bout(fo);
model.Save(bout, wptr); model.Save(&bout, wptr);
bout.Finish('\n'); bout.Finish('\n');
} else { } else {
fo->Write("binf", 4); fo->Write("binf", 4);
model.Save(*fo, wptr); model.Save(fo, wptr);
} }
delete fo; delete fo;
} }
@ -128,11 +128,11 @@ class LinearObjFunction : public solver::IObjFunction<float> {
} }
} }
// load model // load model
virtual void Load(rabit::IStream &fi) { virtual void Load(rabit::Stream *fi) {
fi.Read(&model.param, sizeof(model.param)); fi->Read(&model.param, sizeof(model.param));
} }
virtual void Save(rabit::IStream &fo) const { virtual void Save(rabit::Stream *fo) const {
fo.Write(&model.param, sizeof(model.param)); fo->Write(&model.param, sizeof(model.param));
} }
virtual double Eval(const float *weight, size_t size) { virtual double Eval(const float *weight, size_t size) {
if (nthread != 0) omp_set_num_threads(nthread); if (nthread != 0) omp_set_num_threads(nthread);

View File

@ -113,17 +113,17 @@ struct LinearModel {
if (weight != NULL) delete [] weight; if (weight != NULL) delete [] weight;
} }
// load model // load model
inline void Load(rabit::IStream &fi) { inline void Load(rabit::Stream *fi) {
fi.Read(&param, sizeof(param)); fi->Read(&param, sizeof(param));
if (weight == NULL) { if (weight == NULL) {
weight = new float[param.num_feature + 1]; weight = new float[param.num_feature + 1];
} }
fi.Read(weight, sizeof(float) * (param.num_feature + 1)); fi->Read(weight, sizeof(float) * (param.num_feature + 1));
} }
inline void Save(rabit::IStream &fo, const float *wptr = NULL) { inline void Save(rabit::Stream *fo, const float *wptr = NULL) {
fo.Write(&param, sizeof(param)); fo->Write(&param, sizeof(param));
if (wptr == NULL) wptr = weight; if (wptr == NULL) wptr = weight;
fo.Write(wptr, sizeof(float) * (param.num_feature + 1)); fo->Write(wptr, sizeof(float) * (param.num_feature + 1));
} }
inline float Predict(const SparseMat::Vector &v) const { inline float Predict(const SparseMat::Vector &v) const {
return param.Predict(weight, v); return param.Predict(weight, v);

View File

@ -19,7 +19,7 @@ namespace solver {
* to remember the state parameters that might need to remember * to remember the state parameters that might need to remember
*/ */
template<typename DType> template<typename DType>
class IObjFunction : public rabit::ISerializable { class IObjFunction : public rabit::Serializable {
public: public:
// destructor // destructor
virtual ~IObjFunction(void){} virtual ~IObjFunction(void){}
@ -463,7 +463,7 @@ class LBFGSSolver {
} }
} }
// global solver state // global solver state
struct GlobalState : public rabit::ISerializable { struct GlobalState : public rabit::Serializable {
public: public:
// memory size of L-BFGS // memory size of L-BFGS
size_t size_memory; size_t size_memory;
@ -514,28 +514,28 @@ class LBFGSSolver {
MapIndex(j, offset_, size_memory)]; MapIndex(j, offset_, size_memory)];
} }
// load the shift array // load the shift array
virtual void Load(rabit::IStream &fi) { virtual void Load(rabit::Stream *fi) {
fi.Read(&size_memory, sizeof(size_memory)); fi->Read(&size_memory, sizeof(size_memory));
fi.Read(&num_iteration, sizeof(num_iteration)); fi->Read(&num_iteration, sizeof(num_iteration));
fi.Read(&num_dim, sizeof(num_dim)); fi->Read(&num_dim, sizeof(num_dim));
fi.Read(&init_objval, sizeof(init_objval)); fi->Read(&init_objval, sizeof(init_objval));
fi.Read(&old_objval, sizeof(old_objval)); fi->Read(&old_objval, sizeof(old_objval));
fi.Read(&offset_, sizeof(offset_)); fi->Read(&offset_, sizeof(offset_));
fi.Read(&data); fi->Read(&data);
this->AllocSpace(); this->AllocSpace();
fi.Read(weight, sizeof(DType) * num_dim); fi->Read(weight, sizeof(DType) * num_dim);
obj->Load(fi); obj->Load(fi);
} }
// save the shift array // save the shift array
virtual void Save(rabit::IStream &fo) const { virtual void Save(rabit::Stream *fo) const {
fo.Write(&size_memory, sizeof(size_memory)); fo->Write(&size_memory, sizeof(size_memory));
fo.Write(&num_iteration, sizeof(num_iteration)); fo->Write(&num_iteration, sizeof(num_iteration));
fo.Write(&num_dim, sizeof(num_dim)); fo->Write(&num_dim, sizeof(num_dim));
fo.Write(&init_objval, sizeof(init_objval)); fo->Write(&init_objval, sizeof(init_objval));
fo.Write(&old_objval, sizeof(old_objval)); fo->Write(&old_objval, sizeof(old_objval));
fo.Write(&offset_, sizeof(offset_)); fo->Write(&offset_, sizeof(offset_));
fo.Write(data); fo->Write(data);
fo.Write(weight, sizeof(DType) * num_dim); fo->Write(weight, sizeof(DType) * num_dim);
obj->Save(fo); obj->Save(fo);
} }
inline void Shift(void) { inline void Shift(void) {
@ -556,7 +556,7 @@ class LBFGSSolver {
} }
}; };
/*! \brief rolling array that carries history information */ /*! \brief rolling array that carries history information */
struct HistoryArray : public rabit::ISerializable { struct HistoryArray : public rabit::Serializable {
public: public:
HistoryArray(void) : dptr_(NULL) { HistoryArray(void) : dptr_(NULL) {
num_useful_ = 0; num_useful_ = 0;
@ -609,26 +609,26 @@ class LBFGSSolver {
num_useful_ = num_useful; num_useful_ = num_useful;
} }
// load the shift array // load the shift array
virtual void Load(rabit::IStream &fi) { virtual void Load(rabit::Stream *fi) {
fi.Read(&num_col_, sizeof(num_col_)); fi->Read(&num_col_, sizeof(num_col_));
fi.Read(&stride_, sizeof(stride_)); fi->Read(&stride_, sizeof(stride_));
fi.Read(&size_memory_, sizeof(size_memory_)); fi->Read(&size_memory_, sizeof(size_memory_));
fi.Read(&num_useful_, sizeof(num_useful_)); fi->Read(&num_useful_, sizeof(num_useful_));
this->Init(num_col_, size_memory_); this->Init(num_col_, size_memory_);
for (size_t i = 0; i < num_useful_; ++i) { for (size_t i = 0; i < num_useful_; ++i) {
fi.Read((*this)[i], num_col_ * sizeof(DType)); fi->Read((*this)[i], num_col_ * sizeof(DType));
fi.Read((*this)[i + size_memory_], num_col_ * sizeof(DType)); fi->Read((*this)[i + size_memory_], num_col_ * sizeof(DType));
} }
} }
// save the shift array // save the shift array
virtual void Save(rabit::IStream &fi) const { virtual void Save(rabit::Stream *fo) const {
fi.Write(&num_col_, sizeof(num_col_)); fo->Write(&num_col_, sizeof(num_col_));
fi.Write(&stride_, sizeof(stride_)); fo->Write(&stride_, sizeof(stride_));
fi.Write(&size_memory_, sizeof(size_memory_)); fo->Write(&size_memory_, sizeof(size_memory_));
fi.Write(&num_useful_, sizeof(num_useful_)); fo->Write(&num_useful_, sizeof(num_useful_));
for (size_t i = 0; i < num_useful_; ++i) { for (size_t i = 0; i < num_useful_; ++i) {
fi.Write((*this)[i], num_col_ * sizeof(DType)); fo->Write((*this)[i], num_col_ * sizeof(DType));
fi.Write((*this)[i + size_memory_], num_col_ * sizeof(DType)); fo->Write((*this)[i + size_memory_], num_col_ * sizeof(DType));
} }
} }

View File

@ -93,43 +93,6 @@ struct SparseMat {
std::vector<float> labels; std::vector<float> labels;
}; };
// dense matrix
struct Matrix {
inline void Init(size_t nrow, size_t ncol, float v = 0.0f) {
this->nrow = nrow;
this->ncol = ncol;
data.resize(nrow * ncol);
std::fill(data.begin(), data.end(), v);
}
inline float *operator[](size_t i) {
return &data[0] + i * ncol;
}
inline const float *operator[](size_t i) const {
return &data[0] + i * ncol;
}
inline void Print(const char *fname) {
FILE *fo;
if (!strcmp(fname, "stdout")) {
fo = stdout;
} else {
fo = utils::FopenCheck(fname, "w");
}
for (size_t i = 0; i < data.size(); ++i) {
fprintf(fo, "%g", data[i]);
if ((i+1) % ncol == 0) {
fprintf(fo, "\n");
} else {
fprintf(fo, " ");
}
}
// close the filed
if (fo != stdout) fclose(fo);
}
// number of data
size_t nrow, ncol;
std::vector<float> data;
};
/*!\brief computes a random number modulo the value */ /*!\brief computes a random number modulo the value */
inline int Random(int value) { inline int Random(int value) {
return rand() % value; return rand() % value;

View File

@ -126,8 +126,8 @@ class AllreduceBase : public IEngine {
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
virtual int LoadCheckPoint(ISerializable *global_model, virtual int LoadCheckPoint(Serializable *global_model,
ISerializable *local_model = NULL) { Serializable *local_model = NULL) {
return 0; return 0;
} }
/*! /*!
@ -146,8 +146,8 @@ class AllreduceBase : public IEngine {
* *
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const ISerializable *global_model, virtual void CheckPoint(const Serializable *global_model,
const ISerializable *local_model = NULL) { const Serializable *local_model = NULL) {
version_number += 1; version_number += 1;
} }
/*! /*!
@ -170,7 +170,7 @@ class AllreduceBase : public IEngine {
* is the same in all nodes * is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber * \sa LoadCheckPoint, CheckPoint, VersionNumber
*/ */
virtual void LazyCheckPoint(const ISerializable *global_model) { virtual void LazyCheckPoint(const Serializable *global_model) {
version_number += 1; version_number += 1;
} }
/*! /*!

View File

@ -5,8 +5,8 @@
* *
* \author Ignacio Cano, Tianqi Chen * \author Ignacio Cano, Tianqi Chen
*/ */
#ifndef RABIT_ALLREDUCE_MOCK_H #ifndef RABIT_ALLREDUCE_MOCK_H_
#define RABIT_ALLREDUCE_MOCK_H #define RABIT_ALLREDUCE_MOCK_H_
#include <vector> #include <vector>
#include <map> #include <map>
#include <sstream> #include <sstream>
@ -58,8 +58,8 @@ class AllreduceMock : public AllreduceRobust {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root); AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
} }
virtual int LoadCheckPoint(ISerializable *global_model, virtual int LoadCheckPoint(Serializable *global_model,
ISerializable *local_model) { Serializable *local_model) {
tsum_allreduce = 0.0; tsum_allreduce = 0.0;
time_checkpoint = utils::GetTime(); time_checkpoint = utils::GetTime();
if (force_local == 0) { if (force_local == 0) {
@ -70,8 +70,8 @@ class AllreduceMock : public AllreduceRobust {
return AllreduceRobust::LoadCheckPoint(&dum, &com); return AllreduceRobust::LoadCheckPoint(&dum, &com);
} }
} }
virtual void CheckPoint(const ISerializable *global_model, virtual void CheckPoint(const Serializable *global_model,
const ISerializable *local_model) { const Serializable *local_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
double tstart = utils::GetTime(); double tstart = utils::GetTime();
double tbet_chkpt = tstart - time_checkpoint; double tbet_chkpt = tstart - time_checkpoint;
@ -96,7 +96,7 @@ class AllreduceMock : public AllreduceRobust {
tsum_allreduce = 0.0; tsum_allreduce = 0.0;
} }
virtual void LazyCheckPoint(const ISerializable *global_model) { virtual void LazyCheckPoint(const Serializable *global_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint"); this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
AllreduceRobust::LazyCheckPoint(global_model); AllreduceRobust::LazyCheckPoint(global_model);
} }
@ -110,28 +110,28 @@ class AllreduceMock : public AllreduceRobust {
double time_checkpoint; double time_checkpoint;
private: private:
struct DummySerializer : public ISerializable { struct DummySerializer : public Serializable {
virtual void Load(IStream &fi) { virtual void Load(Stream *fi) {
} }
virtual void Save(IStream &fo) const { virtual void Save(Stream *fo) const {
} }
}; };
struct ComboSerializer : public ISerializable { struct ComboSerializer : public Serializable {
ISerializable *lhs; Serializable *lhs;
ISerializable *rhs; Serializable *rhs;
const ISerializable *c_lhs; const Serializable *c_lhs;
const ISerializable *c_rhs; const Serializable *c_rhs;
ComboSerializer(ISerializable *lhs, ISerializable *rhs) ComboSerializer(Serializable *lhs, Serializable *rhs)
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) { : lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
} }
ComboSerializer(const ISerializable *lhs, const ISerializable *rhs) ComboSerializer(const Serializable *lhs, const Serializable *rhs)
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) { : lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
} }
virtual void Load(IStream &fi) { virtual void Load(Stream *fi) {
if (lhs != NULL) lhs->Load(fi); if (lhs != NULL) lhs->Load(fi);
if (rhs != NULL) rhs->Load(fi); if (rhs != NULL) rhs->Load(fi);
} }
virtual void Save(IStream &fo) const { virtual void Save(Stream *fo) const {
if (c_lhs != NULL) c_lhs->Save(fo); if (c_lhs != NULL) c_lhs->Save(fo);
if (c_rhs != NULL) c_rhs->Save(fo); if (c_rhs != NULL) c_rhs->Save(fo);
} }
@ -173,4 +173,4 @@ class AllreduceMock : public AllreduceRobust {
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit
#endif // RABIT_ALLREDUCE_MOCK_H #endif // RABIT_ALLREDUCE_MOCK_H_

View File

@ -158,8 +158,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
int AllreduceRobust::LoadCheckPoint(ISerializable *global_model, int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
ISerializable *local_model) { Serializable *local_model) {
// skip action in single node // skip action in single node
if (world_size == 1) return 0; if (world_size == 1) return 0;
this->LocalModelCheck(local_model != NULL); this->LocalModelCheck(local_model != NULL);
@ -175,7 +175,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
// load in local model // load in local model
utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]), utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]),
local_rptr[local_chkpt_version][1]); local_rptr[local_chkpt_version][1]);
local_model->Load(fs); local_model->Load(&fs);
} else { } else {
utils::Assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal); utils::Assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal);
} }
@ -189,7 +189,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
} else { } else {
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0,
"read in version number"); "read in version number");
global_model->Load(fs); global_model->Load(&fs);
utils::Assert(local_model == NULL || nlocal == num_local_replica + 1, utils::Assert(local_model == NULL || nlocal == num_local_replica + 1,
"local model inconsistent, nlocal=%d", nlocal); "local model inconsistent, nlocal=%d", nlocal);
} }
@ -241,8 +241,8 @@ void AllreduceRobust::LocalModelCheck(bool with_local) {
* *
* \sa CheckPoint, LazyCheckPoint * \sa CheckPoint, LazyCheckPoint
*/ */
void AllreduceRobust::CheckPoint_(const ISerializable *global_model, void AllreduceRobust::CheckPoint_(const Serializable *global_model,
const ISerializable *local_model, const Serializable *local_model,
bool lazy_checkpt) { bool lazy_checkpt) {
// never do check point in single machine mode // never do check point in single machine mode
if (world_size == 1) { if (world_size == 1) {
@ -261,7 +261,7 @@ void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
local_chkpt[new_version].clear(); local_chkpt[new_version].clear();
utils::MemoryBufferStream fs(&local_chkpt[new_version]); utils::MemoryBufferStream fs(&local_chkpt[new_version]);
if (local_model != NULL) { if (local_model != NULL) {
local_model->Save(fs); local_model->Save(&fs);
} }
local_rptr[new_version].clear(); local_rptr[new_version].clear();
local_rptr[new_version].push_back(0); local_rptr[new_version].push_back(0);
@ -287,7 +287,7 @@ void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
global_checkpoint.resize(0); global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint); utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number)); fs.Write(&version_number, sizeof(version_number));
global_model->Save(fs); global_model->Save(&fs);
global_lazycheck = NULL; global_lazycheck = NULL;
} }
// reset result buffer // reset result buffer
@ -748,7 +748,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
global_checkpoint.resize(0); global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint); utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number)); fs.Write(&version_number, sizeof(version_number));
global_lazycheck->Save(fs); global_lazycheck->Save(&fs);
global_lazycheck = NULL; global_lazycheck = NULL;
} }
// recover global checkpoint // recover global checkpoint

View File

@ -80,8 +80,8 @@ class AllreduceRobust : public AllreduceBase {
* *
* \sa CheckPoint, VersionNumber * \sa CheckPoint, VersionNumber
*/ */
virtual int LoadCheckPoint(ISerializable *global_model, virtual int LoadCheckPoint(Serializable *global_model,
ISerializable *local_model = NULL); Serializable *local_model = NULL);
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one * every time we call check point, there is a version number which will increase by one
@ -98,8 +98,8 @@ class AllreduceRobust : public AllreduceBase {
* *
* \sa LoadCheckPoint, VersionNumber * \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const ISerializable *global_model, virtual void CheckPoint(const Serializable *global_model,
const ISerializable *local_model = NULL) { const Serializable *local_model = NULL) {
this->CheckPoint_(global_model, local_model, false); this->CheckPoint_(global_model, local_model, false);
} }
/*! /*!
@ -122,7 +122,7 @@ class AllreduceRobust : public AllreduceBase {
* is the same in all nodes * is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber * \sa LoadCheckPoint, CheckPoint, VersionNumber
*/ */
virtual void LazyCheckPoint(const ISerializable *global_model) { virtual void LazyCheckPoint(const Serializable *global_model) {
this->CheckPoint_(global_model, NULL, true); this->CheckPoint_(global_model, NULL, true);
} }
/*! /*!
@ -318,8 +318,8 @@ class AllreduceRobust : public AllreduceBase {
* *
* \sa CheckPoint, LazyCheckPoint * \sa CheckPoint, LazyCheckPoint
*/ */
void CheckPoint_(const ISerializable *global_model, void CheckPoint_(const Serializable *global_model,
const ISerializable *local_model, const Serializable *local_model,
bool lazy_checkpt); bool lazy_checkpt);
/*! /*!
* \brief reset the all the existing links by sending Out-of-Band message marker * \brief reset the all the existing links by sending Out-of-Band message marker
@ -521,7 +521,7 @@ o * the input state must exactly one saved state(local state of current node)
// last check point global model // last check point global model
std::string global_checkpoint; std::string global_checkpoint;
// lazy checkpoint of global model // lazy checkpoint of global model
const ISerializable *global_lazycheck; const Serializable *global_lazycheck;
// number of replica for local state/model // number of replica for local state/model
int num_local_replica; int num_local_replica;
// number of default local replica // number of default local replica

View File

@ -37,15 +37,15 @@ class MPIEngine : public IEngine {
virtual void InitAfterException(void) { virtual void InitAfterException(void) {
utils::Error("MPI is not fault tolerant"); utils::Error("MPI is not fault tolerant");
} }
virtual int LoadCheckPoint(ISerializable *global_model, virtual int LoadCheckPoint(Serializable *global_model,
ISerializable *local_model = NULL) { Serializable *local_model = NULL) {
return 0; return 0;
} }
virtual void CheckPoint(const ISerializable *global_model, virtual void CheckPoint(const Serializable *global_model,
const ISerializable *local_model = NULL) { const Serializable *local_model = NULL) {
version_number += 1; version_number += 1;
} }
virtual void LazyCheckPoint(const ISerializable *global_model) { virtual void LazyCheckPoint(const Serializable *global_model) {
version_number += 1; version_number += 1;
} }
virtual int VersionNumber(void) const { virtual int VersionNumber(void) const {

View File

@ -8,17 +8,17 @@
using namespace rabit; using namespace rabit;
// dummy model // dummy model
class Model : public rabit::ISerializable { class Model : public rabit::Serializable {
public: public:
// iterations // iterations
std::vector<float> data; std::vector<float> data;
// load from stream // load from stream
virtual void Load(rabit::IStream &fi) { virtual void Load(rabit::Stream *fi) {
fi.Read(&data); fi->Read(&data);
} }
/*! \brief save the model to the stream */ /*! \brief save the model to the stream */
virtual void Save(rabit::IStream &fo) const { virtual void Save(rabit::Stream *fo) const {
fo.Write(data); fo->Write(data);
} }
virtual void InitModel(size_t n) { virtual void InitModel(size_t n) {
data.clear(); data.clear();

View File

@ -9,17 +9,17 @@
using namespace rabit; using namespace rabit;
// dummy model // dummy model
class Model : public rabit::ISerializable { class Model : public rabit::Serializable {
public: public:
// iterations // iterations
std::vector<float> data; std::vector<float> data;
// load from stream // load from stream
virtual void Load(rabit::IStream &fi) { virtual void Load(rabit::Stream *fi) {
fi.Read(&data); fi->Read(&data);
} }
/*! \brief save the model to the stream */ /*! \brief save the model to the stream */
virtual void Save(rabit::IStream &fo) const { virtual void Save(rabit::Stream *fo) const {
fo.Write(data); fo->Write(data);
} }
virtual void InitModel(size_t n, float v) { virtual void InitModel(size_t n, float v) {
data.clear(); data.clear();

View File

@ -8,17 +8,17 @@
using namespace rabit; using namespace rabit;
// dummy model // dummy model
class Model : public rabit::ISerializable { class Model : public rabit::Serializable {
public: public:
// iterations // iterations
std::vector<float> data; std::vector<float> data;
// load from stream // load from stream
virtual void Load(rabit::IStream &fi) { virtual void Load(rabit::Stream *fi) {
fi.Read(&data); fi->Read(&data);
} }
/*! \brief save the model to the stream */ /*! \brief save the model to the stream */
virtual void Save(rabit::IStream &fo) const { virtual void Save(rabit::Stream *fo) const {
fo.Write(data); fo->Write(data);
} }
virtual void InitModel(size_t n) { virtual void InitModel(size_t n) {
data.clear(); data.clear();

View File

@ -119,38 +119,38 @@ inline void Allreduce(void *sendrecvbuf,
// temporal memory for global and local model // temporal memory for global and local model
std::string global_buffer, local_buffer; std::string global_buffer, local_buffer;
// wrapper for serialization // wrapper for serialization
struct ReadWrapper : public ISerializable { struct ReadWrapper : public Serializable {
std::string *p_str; std::string *p_str;
explicit ReadWrapper(std::string *p_str) explicit ReadWrapper(std::string *p_str)
: p_str(p_str) {} : p_str(p_str) {}
virtual void Load(IStream &fi) { virtual void Load(Stream *fi) {
uint64_t sz; uint64_t sz;
utils::Assert(fi.Read(&sz, sizeof(sz)) != 0, utils::Assert(fi->Read(&sz, sizeof(sz)) != 0,
"Read pickle string"); "Read pickle string");
p_str->resize(sz); p_str->resize(sz);
if (sz != 0) { if (sz != 0) {
utils::Assert(fi.Read(&(*p_str)[0], sizeof(char) * sz) != 0, utils::Assert(fi->Read(&(*p_str)[0], sizeof(char) * sz) != 0,
"Read pickle string"); "Read pickle string");
} }
} }
virtual void Save(IStream &fo) const { virtual void Save(Stream *fo) const {
utils::Error("not implemented"); utils::Error("not implemented");
} }
}; };
struct WriteWrapper : public ISerializable { struct WriteWrapper : public Serializable {
const char *data; const char *data;
size_t length; size_t length;
explicit WriteWrapper(const char *data, explicit WriteWrapper(const char *data,
size_t length) size_t length)
: data(data), length(length) { : data(data), length(length) {
} }
virtual void Load(IStream &fi) { virtual void Load(Stream *fi) {
utils::Error("not implemented"); utils::Error("not implemented");
} }
virtual void Save(IStream &fo) const { virtual void Save(Stream *fo) const {
uint64_t sz = static_cast<uint16_t>(length); uint64_t sz = static_cast<uint16_t>(length);
fo.Write(&sz, sizeof(sz)); fo->Write(&sz, sizeof(sz));
fo.Write(data, length * sizeof(char)); fo->Write(data, length * sizeof(char));
} }
}; };
} // namespace wrapper } // namespace wrapper