Merge commit '3d11f56880521c1d45504c965ae12886e9b72ace'
This commit is contained in:
@@ -14,7 +14,7 @@ namespace dmlc {
|
||||
/*!
|
||||
* \brief interface of stream I/O for serialization
|
||||
*/
|
||||
class IStream {
|
||||
class Stream {
|
||||
public:
|
||||
/*!
|
||||
* \brief reads data from a stream
|
||||
@@ -30,7 +30,7 @@ class IStream {
|
||||
*/
|
||||
virtual void Write(const void *ptr, size_t size) = 0;
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~IStream(void) {}
|
||||
virtual ~Stream(void) {}
|
||||
/*!
|
||||
* \brief generic factory function
|
||||
* create an stream, the stream will close the underlying files
|
||||
@@ -38,8 +38,9 @@ class IStream {
|
||||
* \param uri the uri of the input currently we support
|
||||
* hdfs://, s3://, and file:// by default file:// will be used
|
||||
* \param flag can be "w", "r", "a"
|
||||
* \return a created stream
|
||||
*/
|
||||
static IStream *Create(const char *uri, const char* const flag);
|
||||
static Stream *Create(const char *uri, const char* const flag);
|
||||
// helper functions to write/read different data structures
|
||||
/*!
|
||||
* \brief writes a vector
|
||||
@@ -68,10 +69,10 @@ class IStream {
|
||||
};
|
||||
|
||||
/*! \brief interface of i/o stream that support seek */
|
||||
class ISeekStream: public IStream {
|
||||
class SeekStream: public Stream {
|
||||
public:
|
||||
// virtual destructor
|
||||
virtual ~ISeekStream(void) {}
|
||||
virtual ~SeekStream(void) {}
|
||||
/*! \brief seek to certain position of the file */
|
||||
virtual void Seek(size_t pos) = 0;
|
||||
/*! \brief tell the position of the stream */
|
||||
@@ -81,18 +82,18 @@ class ISeekStream: public IStream {
|
||||
};
|
||||
|
||||
/*! \brief interface for serializable objects */
|
||||
class ISerializable {
|
||||
class Serializable {
|
||||
public:
|
||||
/*!
|
||||
* \brief load the model from a stream
|
||||
* \param fi stream where to load the model from
|
||||
*/
|
||||
virtual void Load(IStream &fi) = 0;
|
||||
virtual void Load(Stream *fi) = 0;
|
||||
/*!
|
||||
* \brief saves the model to a stream
|
||||
* \param fo stream where to save the model to
|
||||
*/
|
||||
virtual void Save(IStream &fo) const = 0;
|
||||
virtual void Save(Stream *fo) const = 0;
|
||||
};
|
||||
|
||||
/*!
|
||||
@@ -115,6 +116,7 @@ class InputSplit {
|
||||
* \param uri the uri of the input, can contain hdfs prefix
|
||||
* \param part_index the part id of current input
|
||||
* \param num_parts total number of splits
|
||||
* \return a created input split
|
||||
*/
|
||||
static InputSplit* Create(const char *uri,
|
||||
unsigned part_index,
|
||||
@@ -123,7 +125,7 @@ class InputSplit {
|
||||
|
||||
// implementations of inline functions
|
||||
template<typename T>
|
||||
inline void IStream::Write(const std::vector<T> &vec) {
|
||||
inline void Stream::Write(const std::vector<T> &vec) {
|
||||
size_t sz = vec.size();
|
||||
this->Write(&sz, sizeof(sz));
|
||||
if (sz != 0) {
|
||||
@@ -131,7 +133,7 @@ inline void IStream::Write(const std::vector<T> &vec) {
|
||||
}
|
||||
}
|
||||
template<typename T>
|
||||
inline bool IStream::Read(std::vector<T> *out_vec) {
|
||||
inline bool Stream::Read(std::vector<T> *out_vec) {
|
||||
size_t sz;
|
||||
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
||||
out_vec->resize(sz);
|
||||
@@ -140,14 +142,14 @@ inline bool IStream::Read(std::vector<T> *out_vec) {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
inline void IStream::Write(const std::string &str) {
|
||||
inline void Stream::Write(const std::string &str) {
|
||||
size_t sz = str.length();
|
||||
this->Write(&sz, sizeof(sz));
|
||||
if (sz != 0) {
|
||||
this->Write(&str[0], sizeof(char) * sz);
|
||||
}
|
||||
}
|
||||
inline bool IStream::Read(std::string *out_str) {
|
||||
inline bool Stream::Read(std::string *out_str) {
|
||||
size_t sz;
|
||||
if (this->Read(&sz, sizeof(sz)) == 0) return false;
|
||||
out_str->resize(sz);
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
#if __cplusplus >= 201103L
|
||||
#include <functional>
|
||||
#endif // C++11
|
||||
// contains definition of ISerializable
|
||||
// contains definition of Serializable
|
||||
#include "./rabit_serializable.h"
|
||||
// engine definition of rabit, defines internal implementation
|
||||
// to use rabit interface, there is no need to read engine.h
|
||||
@@ -183,8 +183,8 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
inline int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL);
|
||||
inline int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief checkpoints the model, meaning a stage of execution has finished.
|
||||
* every time we call check point, a version number will be increased by one
|
||||
@@ -199,8 +199,8 @@ inline int LoadCheckPoint(ISerializable *global_model,
|
||||
* So, only CheckPoint with the global_model if possible
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
inline void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL);
|
||||
inline void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||
* when certain condition is met (see detailed explanation).
|
||||
@@ -222,7 +222,7 @@ inline void CheckPoint(const ISerializable *global_model,
|
||||
* is the same in every node
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
inline void LazyCheckPoint(const ISerializable *global_model);
|
||||
inline void LazyCheckPoint(const Serializable *global_model);
|
||||
/*!
|
||||
* \return version number of the current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
|
||||
@@ -94,8 +94,8 @@ class IEngine {
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) = 0;
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = NULL) = 0;
|
||||
/*!
|
||||
* \brief checkpoints the model, meaning a stage of execution was finished
|
||||
* every time we call check point, a version number increases by ones
|
||||
@@ -112,8 +112,8 @@ class IEngine {
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) = 0;
|
||||
virtual void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = NULL) = 0;
|
||||
/*!
|
||||
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||
* when certain condition is met (see detailed explanation).
|
||||
@@ -134,7 +134,7 @@ class IEngine {
|
||||
* is the same in every node
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) = 0;
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) = 0;
|
||||
/*!
|
||||
* \return version number of the current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
|
||||
@@ -16,10 +16,10 @@
|
||||
|
||||
namespace rabit {
|
||||
namespace utils {
|
||||
/*! \brief re-use definition of dmlc::ISeekStream */
|
||||
typedef dmlc::ISeekStream ISeekStream;
|
||||
/*! \brief re-use definition of dmlc::SeekStream */
|
||||
typedef dmlc::SeekStream SeekStream;
|
||||
/*! \brief fixed size memory buffer */
|
||||
struct MemoryFixSizeBuffer : public ISeekStream {
|
||||
struct MemoryFixSizeBuffer : public SeekStream {
|
||||
public:
|
||||
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
|
||||
: p_buffer_(reinterpret_cast<char*>(p_buffer)),
|
||||
@@ -61,7 +61,7 @@ struct MemoryFixSizeBuffer : public ISeekStream {
|
||||
}; // class MemoryFixSizeBuffer
|
||||
|
||||
/*! \brief a in memory buffer that can be read and write as stream interface */
|
||||
struct MemoryBufferStream : public ISeekStream {
|
||||
struct MemoryBufferStream : public SeekStream {
|
||||
public:
|
||||
explicit MemoryBufferStream(std::string *p_buffer)
|
||||
: p_buffer_(p_buffer) {
|
||||
|
||||
@@ -178,17 +178,17 @@ inline void TrackerPrintf(const char *fmt, ...) {
|
||||
}
|
||||
#endif
|
||||
// load latest check point
|
||||
inline int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model) {
|
||||
inline int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) {
|
||||
return engine::GetEngine()->LoadCheckPoint(global_model, local_model);
|
||||
}
|
||||
// checkpoint the model, meaning we finished a stage of execution
|
||||
inline void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model) {
|
||||
inline void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model) {
|
||||
engine::GetEngine()->CheckPoint(global_model, local_model);
|
||||
}
|
||||
// lazy checkpoint the model, only remember the pointer to global_model
|
||||
inline void LazyCheckPoint(const ISerializable *global_model) {
|
||||
inline void LazyCheckPoint(const Serializable *global_model) {
|
||||
engine::GetEngine()->LazyCheckPoint(global_model);
|
||||
}
|
||||
// return the version number of currently stored model
|
||||
|
||||
@@ -14,14 +14,14 @@
|
||||
namespace rabit {
|
||||
/*!
|
||||
* \brief defines stream used in rabit
|
||||
* see definition of IStream in dmlc/io.h
|
||||
* see definition of Stream in dmlc/io.h
|
||||
*/
|
||||
typedef dmlc::IStream IStream;
|
||||
typedef dmlc::Stream Stream;
|
||||
/*!
|
||||
* \brief defines serializable objects used in rabit
|
||||
* see definition of ISerializable in dmlc/io.h
|
||||
* see definition of Serializable in dmlc/io.h
|
||||
*/
|
||||
typedef dmlc::ISerializable ISerializable;
|
||||
typedef dmlc::Serializable Serializable;
|
||||
|
||||
} // namespace rabit
|
||||
#endif // RABIT_RABIT_SERIALIZABLE_H_
|
||||
|
||||
Reference in New Issue
Block a user