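Remove the "I" prefix from interface names (ISerializable -> Serializable, IStream -> Stream); Serializable::Load/Save now take a Stream pointer instead of a reference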
This commit is contained in:
@@ -126,8 +126,8 @@ class AllreduceBase : public IEngine {
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) {
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
/*!
|
||||
@@ -146,8 +146,8 @@ class AllreduceBase : public IEngine {
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
virtual void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
/*!
|
||||
@@ -170,7 +170,7 @@ class AllreduceBase : public IEngine {
|
||||
* is the same in all nodes
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
||||
version_number += 1;
|
||||
}
|
||||
/*!
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
*
|
||||
* \author Ignacio Cano, Tianqi Chen
|
||||
*/
|
||||
#ifndef RABIT_ALLREDUCE_MOCK_H
|
||||
#define RABIT_ALLREDUCE_MOCK_H
|
||||
#ifndef RABIT_ALLREDUCE_MOCK_H_
|
||||
#define RABIT_ALLREDUCE_MOCK_H_
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
@@ -58,8 +58,8 @@ class AllreduceMock : public AllreduceRobust {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
|
||||
}
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model) {
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) {
|
||||
tsum_allreduce = 0.0;
|
||||
time_checkpoint = utils::GetTime();
|
||||
if (force_local == 0) {
|
||||
@@ -70,8 +70,8 @@ class AllreduceMock : public AllreduceRobust {
|
||||
return AllreduceRobust::LoadCheckPoint(&dum, &com);
|
||||
}
|
||||
}
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model) {
|
||||
virtual void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
|
||||
double tstart = utils::GetTime();
|
||||
double tbet_chkpt = tstart - time_checkpoint;
|
||||
@@ -96,7 +96,7 @@ class AllreduceMock : public AllreduceRobust {
|
||||
tsum_allreduce = 0.0;
|
||||
}
|
||||
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "LazyCheckPoint");
|
||||
AllreduceRobust::LazyCheckPoint(global_model);
|
||||
}
|
||||
@@ -110,28 +110,28 @@ class AllreduceMock : public AllreduceRobust {
|
||||
double time_checkpoint;
|
||||
|
||||
private:
|
||||
struct DummySerializer : public ISerializable {
|
||||
virtual void Load(IStream &fi) {
|
||||
struct DummySerializer : public Serializable {
|
||||
virtual void Load(Stream *fi) {
|
||||
}
|
||||
virtual void Save(IStream &fo) const {
|
||||
virtual void Save(Stream *fo) const {
|
||||
}
|
||||
};
|
||||
struct ComboSerializer : public ISerializable {
|
||||
ISerializable *lhs;
|
||||
ISerializable *rhs;
|
||||
const ISerializable *c_lhs;
|
||||
const ISerializable *c_rhs;
|
||||
ComboSerializer(ISerializable *lhs, ISerializable *rhs)
|
||||
struct ComboSerializer : public Serializable {
|
||||
Serializable *lhs;
|
||||
Serializable *rhs;
|
||||
const Serializable *c_lhs;
|
||||
const Serializable *c_rhs;
|
||||
ComboSerializer(Serializable *lhs, Serializable *rhs)
|
||||
: lhs(lhs), rhs(rhs), c_lhs(lhs), c_rhs(rhs) {
|
||||
}
|
||||
ComboSerializer(const ISerializable *lhs, const ISerializable *rhs)
|
||||
ComboSerializer(const Serializable *lhs, const Serializable *rhs)
|
||||
: lhs(NULL), rhs(NULL), c_lhs(lhs), c_rhs(rhs) {
|
||||
}
|
||||
virtual void Load(IStream &fi) {
|
||||
virtual void Load(Stream *fi) {
|
||||
if (lhs != NULL) lhs->Load(fi);
|
||||
if (rhs != NULL) rhs->Load(fi);
|
||||
}
|
||||
virtual void Save(IStream &fo) const {
|
||||
virtual void Save(Stream *fo) const {
|
||||
if (c_lhs != NULL) c_lhs->Save(fo);
|
||||
if (c_rhs != NULL) c_rhs->Save(fo);
|
||||
}
|
||||
@@ -173,4 +173,4 @@ class AllreduceMock : public AllreduceRobust {
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
#endif // RABIT_ALLREDUCE_MOCK_H
|
||||
#endif // RABIT_ALLREDUCE_MOCK_H_
|
||||
|
||||
@@ -158,8 +158,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model) {
|
||||
int AllreduceRobust::LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model) {
|
||||
// skip action in single node
|
||||
if (world_size == 1) return 0;
|
||||
this->LocalModelCheck(local_model != NULL);
|
||||
@@ -175,7 +175,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
||||
// load in local model
|
||||
utils::MemoryFixSizeBuffer fs(BeginPtr(local_chkpt[local_chkpt_version]),
|
||||
local_rptr[local_chkpt_version][1]);
|
||||
local_model->Load(fs);
|
||||
local_model->Load(&fs);
|
||||
} else {
|
||||
utils::Assert(nlocal == 0, "[%d] local model inconsistent, nlocal=%d", rank, nlocal);
|
||||
}
|
||||
@@ -189,7 +189,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
||||
} else {
|
||||
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0,
|
||||
"read in version number");
|
||||
global_model->Load(fs);
|
||||
global_model->Load(&fs);
|
||||
utils::Assert(local_model == NULL || nlocal == num_local_replica + 1,
|
||||
"local model inconsistent, nlocal=%d", nlocal);
|
||||
}
|
||||
@@ -241,8 +241,8 @@ void AllreduceRobust::LocalModelCheck(bool with_local) {
|
||||
*
|
||||
* \sa CheckPoint, LazyCheckPoint
|
||||
*/
|
||||
void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
|
||||
const ISerializable *local_model,
|
||||
void AllreduceRobust::CheckPoint_(const Serializable *global_model,
|
||||
const Serializable *local_model,
|
||||
bool lazy_checkpt) {
|
||||
// never do check point in single machine mode
|
||||
if (world_size == 1) {
|
||||
@@ -261,7 +261,7 @@ void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
|
||||
local_chkpt[new_version].clear();
|
||||
utils::MemoryBufferStream fs(&local_chkpt[new_version]);
|
||||
if (local_model != NULL) {
|
||||
local_model->Save(fs);
|
||||
local_model->Save(&fs);
|
||||
}
|
||||
local_rptr[new_version].clear();
|
||||
local_rptr[new_version].push_back(0);
|
||||
@@ -287,7 +287,7 @@ void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
|
||||
global_checkpoint.resize(0);
|
||||
utils::MemoryBufferStream fs(&global_checkpoint);
|
||||
fs.Write(&version_number, sizeof(version_number));
|
||||
global_model->Save(fs);
|
||||
global_model->Save(&fs);
|
||||
global_lazycheck = NULL;
|
||||
}
|
||||
// reset result buffer
|
||||
@@ -748,7 +748,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
||||
global_checkpoint.resize(0);
|
||||
utils::MemoryBufferStream fs(&global_checkpoint);
|
||||
fs.Write(&version_number, sizeof(version_number));
|
||||
global_lazycheck->Save(fs);
|
||||
global_lazycheck->Save(&fs);
|
||||
global_lazycheck = NULL;
|
||||
}
|
||||
// recover global checkpoint
|
||||
|
||||
@@ -80,8 +80,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL);
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
@@ -98,8 +98,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
virtual void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = NULL) {
|
||||
this->CheckPoint_(global_model, local_model, false);
|
||||
}
|
||||
/*!
|
||||
@@ -122,7 +122,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* is the same in all nodes
|
||||
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
||||
this->CheckPoint_(global_model, NULL, true);
|
||||
}
|
||||
/*!
|
||||
@@ -318,8 +318,8 @@ class AllreduceRobust : public AllreduceBase {
|
||||
*
|
||||
* \sa CheckPoint, LazyCheckPoint
|
||||
*/
|
||||
void CheckPoint_(const ISerializable *global_model,
|
||||
const ISerializable *local_model,
|
||||
void CheckPoint_(const Serializable *global_model,
|
||||
const Serializable *local_model,
|
||||
bool lazy_checkpt);
|
||||
/*!
|
||||
* \brief reset all the existing links by sending an Out-of-Band message marker
|
||||
@@ -521,7 +521,7 @@ * the input state must be exactly one saved state (local state of current node)
|
||||
// last check point global model
|
||||
std::string global_checkpoint;
|
||||
// lazy checkpoint of global model
|
||||
const ISerializable *global_lazycheck;
|
||||
const Serializable *global_lazycheck;
|
||||
// number of replica for local state/model
|
||||
int num_local_replica;
|
||||
// number of default local replica
|
||||
|
||||
@@ -37,15 +37,15 @@ class MPIEngine : public IEngine {
|
||||
virtual void InitAfterException(void) {
|
||||
utils::Error("MPI is not fault tolerant");
|
||||
}
|
||||
virtual int LoadCheckPoint(ISerializable *global_model,
|
||||
ISerializable *local_model = NULL) {
|
||||
virtual int LoadCheckPoint(Serializable *global_model,
|
||||
Serializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model = NULL) {
|
||||
virtual void CheckPoint(const Serializable *global_model,
|
||||
const Serializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||
virtual void LazyCheckPoint(const Serializable *global_model) {
|
||||
version_number += 1;
|
||||
}
|
||||
virtual int VersionNumber(void) const {
|
||||
|
||||
Reference in New Issue
Block a user