add lazy check, need test, find a race condition

This commit is contained in:
tqchen
2015-01-14 11:58:43 -08:00
parent bddfa2fc24
commit 87c7817124
9 changed files with 192 additions and 31 deletions

View File

@@ -146,6 +146,29 @@ class AllreduceBase : public IEngine {
const ISerializable *local_model = NULL) {
version_number += 1;
}
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination).
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const ISerializable *global_model) {
version_number += 1;
}
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far

View File

@@ -25,6 +25,8 @@ AllreduceRobust::AllreduceRobust(void) {
seq_counter = 0;
local_chkpt_version = 0;
result_buffer_round = 1;
global_lazycheck = NULL;
use_local_model = -1;
}
void AllreduceRobust::Init(void) {
AllreduceBase::Init();
@@ -154,9 +156,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {
// skip action in single node
if (world_size == 1) return 0;
if (local_model != NULL && num_local_replica == 0) {
num_local_replica = default_local_replica;
}
this->LocalModelCheck(local_model != NULL);
if (num_local_replica == 0) {
utils::Check(local_model == NULL,
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
@@ -199,30 +199,50 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
}
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
* \brief internal consistency check function,
* use check to ensure user always call CheckPoint/LoadCheckPoint
* with or without local but not both, this function will set the approperiate settings
* in the first call of LoadCheckPoint/CheckPoint
*
* \param with_local whether the user calls CheckPoint with local model
*/
void AllreduceRobust::LocalModelCheck(bool with_local) {
if (use_local_model == -1) {
if (with_local) {
use_local_model = 1;
if (num_local_replica == 0) {
num_local_replica = default_local_replica;
}
} else {
use_local_model = 0;
num_local_replica = 0;
}
} else {
utils::Check(use_local_model == int(with_local),
"Can only call Checkpoint/LoadCheckPoint always with"\
"or without local_model, but not mixed case");
}
}
/*!
* \brief internal implementation of checkpoint, support both lazy and normal way
*
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
* \param lazy_checkpt whether the action is lazy checkpoint
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
* So only CheckPoint with global_model if possible
*
* \sa LoadCheckPoint, VersionNumber
* \sa CheckPoint, LazyCheckPoint
*/
void AllreduceRobust::CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
const ISerializable *local_model,
bool lazy_checkpt) {
// never do check point in single machine mode
if (world_size == 1) {
version_number += 1; return;
}
if (local_model != NULL && num_local_replica == 0) {
num_local_replica = default_local_replica;
}
this->LocalModelCheck(local_model != NULL);
if (num_local_replica == 0) {
utils::Check(local_model == NULL,
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
@@ -255,10 +275,15 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
// increase version number
version_number += 1;
// save model
global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number));
global_model->Save(fs);
if (lazy_checkpt) {
global_lazycheck = global_model;
} else {
global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number));
global_model->Save(fs);
global_lazycheck = NULL;
}
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
@@ -698,6 +723,14 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
utils::Check(state == 1 || state == 2,
"LoadCheckPoint: too many nodes fails, cannot recover local state");
}
// do call save model if the checkpoint was lazy
if (role == kHaveData && global_lazycheck != NULL) {
global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number));
global_lazycheck->Save(fs);
global_lazycheck = NULL;
}
// recover global checkpoint
size_t size = this->global_checkpoint.length();
int recv_link;

View File

@@ -99,7 +99,32 @@ class AllreduceRobust : public AllreduceBase {
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL);
const ISerializable *local_model = NULL) {
this->CheckPoint_(global_model, local_model, false);
}
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination).
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const ISerializable *global_model) {
this->CheckPoint_(global_model, NULL, true);
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
@@ -274,10 +299,38 @@ class AllreduceRobust : public AllreduceBase {
std::vector<uint64_t> data_;
};
/*!
* \brief reset the all the existing links by sending Out-of-Band message marker
* after this function finishes, all the messages received and sent before in all live links are discarded,
* This allows us to get a fresh start after error has happened
* \brief internal consistency check function,
* use check to ensure user always call CheckPoint/LoadCheckPoint
* with or without local but not both, this function will set the approperiate settings
* in the first call of LoadCheckPoint/CheckPoint
*
* \param with_local whether the user calls CheckPoint with local model
*/
void LocalModelCheck(bool with_local);
/*!
* \brief internal implementation of checkpoint, support both lazy and normal way
*
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
* \param lazy_checkpt whether the action is lazy checkpoint
*
* \sa CheckPoint, LazyCheckPoint
*/
void CheckPoint_(const ISerializable *global_model,
const ISerializable *local_model,
bool lazy_checkpt);
/*!
* \brief reset the all the existing links by sending Out-of-Band message marker
* after this function finishes, all the messages received and sent
* before in all live links are discarded,
* This allows us to get a fresh start after error has happened
*
* TODO(tqchen): this function is not yet functioning was not used by engine,
* simple resetlink and reconnect strategy is used
*
* \return this function can return kSuccess or kSockError
* when kSockError is returned, it simply means there are bad sockets in the links,
* and some link recovery proceduer is needed
@@ -468,10 +521,14 @@ o * the input state must exactly one saved state(local state of current node)
ResultBuffer resbuf;
// last check point global model
std::string global_checkpoint;
// lazy checkpoint of global model
const ISerializable *global_lazycheck;
// number of replica for local state/model
int num_local_replica;
// number of default local replica
int default_local_replica;
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
int use_local_model;
// number of replica for global state/model
int num_global_replica;
// --- recovery data structure for local checkpoint

View File

@@ -42,6 +42,9 @@ class EmptyEngine : public IEngine {
const ISerializable *local_model = NULL) {
version_number += 1;
}
virtual void LazyCheckPoint(const ISerializable *global_model) {
version_number += 1;
}
virtual int VersionNumber(void) const {
return version_number;
}

View File

@@ -45,6 +45,9 @@ class MPIEngine : public IEngine {
const ISerializable *local_model = NULL) {
version_number += 1;
}
virtual void LazyCheckPoint(const ISerializable *global_model) {
version_number += 1;
}
virtual int VersionNumber(void) const {
return version_number;
}