remove I prefix from interface, serializable now takes in pointer

This commit is contained in:
tqchen
2015-04-08 15:25:58 -07:00
parent b15f6cd2ac
commit e95c96232a
27 changed files with 221 additions and 226 deletions

View File

@@ -126,8 +126,8 @@ class AllreduceBase : public IEngine {
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL) {
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) {
return 0;
}
/*!
@@ -146,8 +146,8 @@ class AllreduceBase : public IEngine {
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) {
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
version_number += 1;
}
/*!
@@ -170,7 +170,7 @@ class AllreduceBase : public IEngine {
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const ISerializable *global_model) {
virtual void LazyCheckPoint(const Serializable *global_model) {
version_number += 1;
}
/*!

View File

@@ -5,8 +5,8 @@
*
* \author Ignacio Cano, Tianqi Chen
*/
#ifndef RABIT_ALLREDUCE_MOCK_H
#define RABIT_ALLREDUCE_MOCK_H
#ifndef RABIT_ALLREDUCE_MOCK_H_
#define RABIT_ALLREDUCE_MOCK_H_
#include <vector>
#include <map>
#include <sstream>
@@ -58,8 +58,8 @@ class AllreduceMock : public AllreduceRobust {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
}
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model) {
tsum_allreduce = 0.0;
time_checkpoint = utils::GetTime();
if (force_local == 0) {
@@ -70,8 +70,8 @@ class AllreduceMock : public AllreduceRobust {
return AllreduceRobust::LoadCheckPoint(&dum, &com);
}
}
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
double tstart = utils::GetTime();
double tbet_chkpt = tstart - time_checkpoint;
@@ -96,7 +96,7 @@ class AllreduceMock : public AllreduceRobust {
tsum_allreduce = 0.0;
}
virtual void LazyCheckPoint(const ISerializable *global_model) {
virtual void LazyCheckPoint(const Serializable *global_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
AllreduceRobust::LazyCheckPoint(global_model);
}
@@ -110,28 +110,28 @@ class AllreduceMock : public AllreduceRobust {
double time_checkpoint;
private:
struct DummySerializer : public ISerializable {
virtual void Load(IStream &fi) {
struct DummySerializer : public Serializable {
virtual void Load(Stream *fi) {
}
virtual void Save(IStream &fo) const {
virtual void Save(Stream *fo) const {
}
};
struct ComboSerializer : public ISerializable {
ISerializable *lhs;
ISerializable *rhs;
const ISerializable *c_lhs;
const ISerializable *c_rhs;
ComboSerializer(ISerializable *lhs, ISerializable *rhs)
struct ComboSerializer : public Serializable {
Serializable *lhs;
Serializable *rhs;
const Serializable *c_lhs;
const Serializable *c_rhs;
ComboSerializer(Serializable *lhs, Serializable *rhs)
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
}
ComboSerializer(const ISerializable *lhs, const ISerializable *rhs)
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
}
virtual void Load(IStream &fi) {
virtual void Load(Stream *fi) {
if (lhs != NULL) lhs->Load(fi);
if (rhs != NULL) rhs->Load(fi);
}
virtual void Save(IStream &fo) const {
virtual void Save(Stream *fo) const {
if (c_lhs != NULL) c_lhs->Save(fo);
if (c_rhs != NULL) c_rhs->Save(fo);
}
@@ -173,4 +173,4 @@ class AllreduceMock : public AllreduceRobust {
};
} // namespace engine
} // namespace rabit
#endif // RABIT_ALLREDUCE_MOCK_H
#endif // RABIT_ALLREDUCE_MOCK_H_

View File

@@ -158,8 +158,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
*
* \sa CheckPoint, VersionNumber
*/
int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {
int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
Serializable *local_model) {
// skip action in single node
if (world_size == 1) return 0;
this->LocalModelCheck(local_model != NULL);
@@ -175,7 +175,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
// load in local model
utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]),
local_rptr[local_chkpt_version][1]);
local_model->Load(fs);
local_model->Load(&fs);
} else {
utils::Assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal);
}
@@ -189,7 +189,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
} else {
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0,
"read in version number");
global_model->Load(fs);
global_model->Load(&fs);
utils::Assert(local_model == NULL || nlocal == num_local_replica + 1,
"local model inconsistent, nlocal=%d", nlocal);
}
@@ -241,8 +241,8 @@ void AllreduceRobust::LocalModelCheck(bool with_local) {
*
* \sa CheckPoint, LazyCheckPoint
*/
void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
const ISerializable *local_model,
void AllreduceRobust::CheckPoint_(const Serializable *global_model,
const Serializable *local_model,
bool lazy_checkpt) {
// never do check point in single machine mode
if (world_size == 1) {
@@ -261,7 +261,7 @@ void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
local_chkpt[new_version].clear();
utils::MemoryBufferStream fs(&local_chkpt[new_version]);
if (local_model != NULL) {
local_model->Save(fs);
local_model->Save(&fs);
}
local_rptr[new_version].clear();
local_rptr[new_version].push_back(0);
@@ -287,7 +287,7 @@ void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number));
global_model->Save(fs);
global_model->Save(&fs);
global_lazycheck = NULL;
}
// reset result buffer
@@ -748,7 +748,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number));
global_lazycheck->Save(fs);
global_lazycheck->Save(&fs);
global_lazycheck = NULL;
}
// recover global checkpoint

View File

@@ -80,8 +80,8 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL);
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL);
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
@@ -98,8 +98,8 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) {
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
this->CheckPoint_(global_model, local_model, false);
}
/*!
@@ -122,7 +122,7 @@ class AllreduceRobust : public AllreduceBase {
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const ISerializable *global_model) {
virtual void LazyCheckPoint(const Serializable *global_model) {
this->CheckPoint_(global_model, NULL, true);
}
/*!
@@ -318,8 +318,8 @@ class AllreduceRobust : public AllreduceBase {
*
* \sa CheckPoint, LazyCheckPoint
*/
void CheckPoint_(const ISerializable *global_model,
const ISerializable *local_model,
void CheckPoint_(const Serializable *global_model,
const Serializable *local_model,
bool lazy_checkpt);
/*!
* \brief reset the all the existing links by sending Out-of-Band message marker
@@ -521,7 +521,7 @@ o * the input state must exactly one saved state(local state of current node)
// last check point global model
std::string global_checkpoint;
// lazy checkpoint of global model
const ISerializable *global_lazycheck;
const Serializable *global_lazycheck;
// number of replica for local state/model
int num_local_replica;
// number of default local replica

View File

@@ -37,15 +37,15 @@ class MPIEngine : public IEngine {
virtual void InitAfterException(void) {
utils::Error("MPI is not fault tolerant");
}
virtual int LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model = NULL) {
virtual int LoadCheckPoint(Serializable *global_model,
Serializable *local_model = NULL) {
return 0;
}
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) {
virtual void CheckPoint(const Serializable *global_model,
const Serializable *local_model = NULL) {
version_number += 1;
}
virtual void LazyCheckPoint(const ISerializable *global_model) {
virtual void LazyCheckPoint(const Serializable *global_model) {
version_number += 1;
}
virtual int VersionNumber(void) const {