add lazy check, need test, find a race condition

This commit is contained in:
tqchen 2015-01-14 11:58:43 -08:00
parent bddfa2fc24
commit 87c7817124
9 changed files with 192 additions and 31 deletions

View File

@ -203,6 +203,27 @@ inline int LoadCheckPoint(ISerializable *global_model,
*/
inline void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL);
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination).
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
inline void LazyCheckPoint(const ISerializable *global_model);
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far

View File

@ -114,6 +114,27 @@ class IEngine {
*/
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) = 0;
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination.
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const ISerializable *global_model) = 0;
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far

View File

@ -183,6 +183,10 @@ inline void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
engine::GetEngine()->CheckPoint(global_model, local_model);
}
// lazy checkpoint the model, only remember the pointer to global_model
inline void LazyCheckPoint(const ISerializable *global_model) {
engine::GetEngine()->LazyCheckPoint(global_model);
}
// return the version number of currently stored model
inline int VersionNumber(void) {
return engine::GetEngine()->VersionNumber();

View File

@ -146,6 +146,29 @@ class AllreduceBase : public IEngine {
const ISerializable *local_model = NULL) {
version_number += 1;
}
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination).
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const ISerializable *global_model) {
version_number += 1;
}
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far

View File

@ -25,6 +25,8 @@ AllreduceRobust::AllreduceRobust(void) {
seq_counter = 0;
local_chkpt_version = 0;
result_buffer_round = 1;
global_lazycheck = NULL;
use_local_model = -1;
}
void AllreduceRobust::Init(void) {
AllreduceBase::Init();
@ -154,9 +156,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
ISerializable *local_model) {
// skip action in single node
if (world_size == 1) return 0;
if (local_model != NULL && num_local_replica == 0) {
num_local_replica = default_local_replica;
}
this->LocalModelCheck(local_model != NULL);
if (num_local_replica == 0) {
utils::Check(local_model == NULL,
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
@ -199,30 +199,50 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
}
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
* \brief internal consistency check function,
* use check to ensure user always call CheckPoint/LoadCheckPoint
* with or without local but not both, this function will set the approperiate settings
* in the first call of LoadCheckPoint/CheckPoint
*
* \param with_local whether the user calls CheckPoint with local model
*/
void AllreduceRobust::LocalModelCheck(bool with_local) {
if (use_local_model == -1) {
if (with_local) {
use_local_model = 1;
if (num_local_replica == 0) {
num_local_replica = default_local_replica;
}
} else {
use_local_model = 0;
num_local_replica = 0;
}
} else {
utils::Check(use_local_model == int(with_local),
"Can only call Checkpoint/LoadCheckPoint always with"\
"or without local_model, but not mixed case");
}
}
/*!
* \brief internal implementation of checkpoint, support both lazy and normal way
*
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
* \param lazy_checkpt whether the action is lazy checkpoint
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
* So only CheckPoint with global_model if possible
*
* \sa LoadCheckPoint, VersionNumber
* \sa CheckPoint, LazyCheckPoint
*/
void AllreduceRobust::CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
const ISerializable *local_model,
bool lazy_checkpt) {
// never do check point in single machine mode
if (world_size == 1) {
version_number += 1; return;
}
if (local_model != NULL && num_local_replica == 0) {
num_local_replica = default_local_replica;
}
this->LocalModelCheck(local_model != NULL);
if (num_local_replica == 0) {
utils::Check(local_model == NULL,
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
@ -255,10 +275,15 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
// increase version number
version_number += 1;
// save model
global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number));
global_model->Save(fs);
if (lazy_checkpt) {
global_lazycheck = global_model;
} else {
global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number));
global_model->Save(fs);
global_lazycheck = NULL;
}
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
@ -698,6 +723,14 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
utils::Check(state == 1 || state == 2,
"LoadCheckPoint: too many nodes fails, cannot recover local state");
}
// do call save model if the checkpoint was lazy
if (role == kHaveData && global_lazycheck != NULL) {
global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number));
global_lazycheck->Save(fs);
global_lazycheck = NULL;
}
// recover global checkpoint
size_t size = this->global_checkpoint.length();
int recv_link;

View File

@ -99,7 +99,32 @@ class AllreduceRobust : public AllreduceBase {
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL);
const ISerializable *local_model = NULL) {
this->CheckPoint_(global_model, local_model, false);
}
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination).
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const ISerializable *global_model) {
this->CheckPoint_(global_model, NULL, true);
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
@ -274,10 +299,38 @@ class AllreduceRobust : public AllreduceBase {
std::vector<uint64_t> data_;
};
/*!
* \brief reset the all the existing links by sending Out-of-Band message marker
* after this function finishes, all the messages received and sent before in all live links are discarded,
* This allows us to get a fresh start after error has happened
* \brief internal consistency check function,
* use check to ensure user always call CheckPoint/LoadCheckPoint
* with or without local but not both, this function will set the approperiate settings
* in the first call of LoadCheckPoint/CheckPoint
*
* \param with_local whether the user calls CheckPoint with local model
*/
void LocalModelCheck(bool with_local);
/*!
* \brief internal implementation of checkpoint, support both lazy and normal way
*
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
* \param lazy_checkpt whether the action is lazy checkpoint
*
* \sa CheckPoint, LazyCheckPoint
*/
void CheckPoint_(const ISerializable *global_model,
const ISerializable *local_model,
bool lazy_checkpt);
/*!
* \brief reset the all the existing links by sending Out-of-Band message marker
* after this function finishes, all the messages received and sent
* before in all live links are discarded,
* This allows us to get a fresh start after error has happened
*
* TODO(tqchen): this function is not yet functioning was not used by engine,
* simple resetlink and reconnect strategy is used
*
* \return this function can return kSuccess or kSockError
* when kSockError is returned, it simply means there are bad sockets in the links,
* and some link recovery proceduer is needed
@ -468,10 +521,14 @@ o * the input state must exactly one saved state(local state of current node)
ResultBuffer resbuf;
// last check point global model
std::string global_checkpoint;
// lazy checkpoint of global model
const ISerializable *global_lazycheck;
// number of replica for local state/model
int num_local_replica;
// number of default local replica
int default_local_replica;
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
int use_local_model;
// number of replica for global state/model
int num_global_replica;
// --- recovery data structure for local checkpoint

View File

@ -42,6 +42,9 @@ class EmptyEngine : public IEngine {
const ISerializable *local_model = NULL) {
version_number += 1;
}
virtual void LazyCheckPoint(const ISerializable *global_model) {
version_number += 1;
}
virtual int VersionNumber(void) const {
return version_number;
}

View File

@ -45,6 +45,9 @@ class MPIEngine : public IEngine {
const ISerializable *local_model = NULL) {
version_number += 1;
}
virtual void LazyCheckPoint(const ISerializable *global_model) {
version_number += 1;
}
virtual int VersionNumber(void) const {
return version_number;
}

View File

@ -1,13 +1,8 @@
ifndef $(nslave)
nslave=2
endif
ifndef $(ndata)
ndata=10
endif
# this is a makefile used to show testcases of rabit
.PHONY: model_recover local_recover speed
.PHONY:
test:
../tracker/rabit_mpi.py -v 1 -n 10 bash keepalive.sh test_model_recover 1 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=8,1,2,0
# this experiment test recovery with actually process exit, use keepalive to keep program alive
model_recover_10_10k:
@ -18,3 +13,4 @@ model_recover_10_10k_die_same:
model_recover_10_10k_die_hard:
../tracker/rabit_demo.py -n 10 test_model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0