add lazy check, need test, find a race condition
This commit is contained in:
parent
bddfa2fc24
commit
87c7817124
@ -203,6 +203,27 @@ inline int LoadCheckPoint(ISerializable *global_model,
|
|||||||
*/
|
*/
|
||||||
inline void CheckPoint(const ISerializable *global_model,
|
inline void CheckPoint(const ISerializable *global_model,
|
||||||
const ISerializable *local_model = NULL);
|
const ISerializable *local_model = NULL);
|
||||||
|
/*!
|
||||||
|
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||||
|
* when certain conditions are met (see detailed explanation).
|
||||||
|
*
|
||||||
|
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||||
|
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||||
|
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
|
||||||
|
* In other words, global_model can be changed only between the last call of
|
||||||
|
* Allreduce/Broadcast and LazyCheckPoint in current version
|
||||||
|
*
|
||||||
|
* For example, suppose the calling sequence is:
|
||||||
|
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||||
|
*
|
||||||
|
* If the user only changes global_model in code3, then LazyCheckPoint can be used to
|
||||||
|
* improve efficiency of the program.
|
||||||
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller needs to guarantee that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||||
|
*/
|
||||||
|
inline void LazyCheckPoint(const ISerializable *global_model);
|
||||||
/*!
|
/*!
|
||||||
* \return version number of current stored model,
|
* \return version number of current stored model,
|
||||||
* which means how many calls to CheckPoint we made so far
|
* which means how many calls to CheckPoint we made so far
|
||||||
|
|||||||
@ -114,6 +114,27 @@ class IEngine {
|
|||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const ISerializable *global_model,
|
virtual void CheckPoint(const ISerializable *global_model,
|
||||||
const ISerializable *local_model = NULL) = 0;
|
const ISerializable *local_model = NULL) = 0;
|
||||||
|
/*!
|
||||||
|
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||||
|
* when certain conditions are met (see detailed explanation).
|
||||||
|
*
|
||||||
|
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||||
|
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||||
|
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
|
||||||
|
* In other words, global_model can be changed only between the last call of
|
||||||
|
* Allreduce/Broadcast and LazyCheckPoint in current version
|
||||||
|
*
|
||||||
|
* For example, suppose the calling sequence is:
|
||||||
|
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||||
|
*
|
||||||
|
* If the user only changes global_model in code3, then LazyCheckPoint can be used to
|
||||||
|
* improve efficiency of the program.
|
||||||
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller needs to guarantee that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||||
|
*/
|
||||||
|
virtual void LazyCheckPoint(const ISerializable *global_model) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \return version number of current stored model,
|
* \return version number of current stored model,
|
||||||
* which means how many calls to CheckPoint we made so far
|
* which means how many calls to CheckPoint we made so far
|
||||||
|
|||||||
@ -183,6 +183,10 @@ inline void CheckPoint(const ISerializable *global_model,
|
|||||||
const ISerializable *local_model) {
|
const ISerializable *local_model) {
|
||||||
engine::GetEngine()->CheckPoint(global_model, local_model);
|
engine::GetEngine()->CheckPoint(global_model, local_model);
|
||||||
}
|
}
|
||||||
|
// lazy checkpoint the model, only remember the pointer to global_model
|
||||||
|
inline void LazyCheckPoint(const ISerializable *global_model) {
|
||||||
|
engine::GetEngine()->LazyCheckPoint(global_model);
|
||||||
|
}
|
||||||
// return the version number of currently stored model
|
// return the version number of currently stored model
|
||||||
inline int VersionNumber(void) {
|
inline int VersionNumber(void) {
|
||||||
return engine::GetEngine()->VersionNumber();
|
return engine::GetEngine()->VersionNumber();
|
||||||
|
|||||||
@ -146,6 +146,29 @@ class AllreduceBase : public IEngine {
|
|||||||
const ISerializable *local_model = NULL) {
|
const ISerializable *local_model = NULL) {
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||||
|
* when certain conditions are met (see detailed explanation).
|
||||||
|
*
|
||||||
|
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||||
|
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||||
|
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
|
||||||
|
* In other words, global_model can be changed only between the last call of
|
||||||
|
* Allreduce/Broadcast and LazyCheckPoint in current version
|
||||||
|
*
|
||||||
|
* For example, suppose the calling sequence is:
|
||||||
|
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||||
|
*
|
||||||
|
* If the user only changes global_model in code3, then LazyCheckPoint can be used to
|
||||||
|
* improve efficiency of the program.
|
||||||
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller needs to guarantee that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||||
|
*/
|
||||||
|
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||||
|
version_number += 1;
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \return version number of current stored model,
|
* \return version number of current stored model,
|
||||||
* which means how many calls to CheckPoint we made so far
|
* which means how many calls to CheckPoint we made so far
|
||||||
|
|||||||
@ -25,6 +25,8 @@ AllreduceRobust::AllreduceRobust(void) {
|
|||||||
seq_counter = 0;
|
seq_counter = 0;
|
||||||
local_chkpt_version = 0;
|
local_chkpt_version = 0;
|
||||||
result_buffer_round = 1;
|
result_buffer_round = 1;
|
||||||
|
global_lazycheck = NULL;
|
||||||
|
use_local_model = -1;
|
||||||
}
|
}
|
||||||
void AllreduceRobust::Init(void) {
|
void AllreduceRobust::Init(void) {
|
||||||
AllreduceBase::Init();
|
AllreduceBase::Init();
|
||||||
@ -154,9 +156,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
|||||||
ISerializable *local_model) {
|
ISerializable *local_model) {
|
||||||
// skip action in single node
|
// skip action in single node
|
||||||
if (world_size == 1) return 0;
|
if (world_size == 1) return 0;
|
||||||
if (local_model != NULL && num_local_replica == 0) {
|
this->LocalModelCheck(local_model != NULL);
|
||||||
num_local_replica = default_local_replica;
|
|
||||||
}
|
|
||||||
if (num_local_replica == 0) {
|
if (num_local_replica == 0) {
|
||||||
utils::Check(local_model == NULL,
|
utils::Check(local_model == NULL,
|
||||||
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
|
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
|
||||||
@ -199,30 +199,50 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief internal consistency check function,
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* use check to ensure user always call CheckPoint/LoadCheckPoint
|
||||||
|
* either with or without local_model, but never mixed; this function will set the appropriate settings
|
||||||
|
* in the first call of LoadCheckPoint/CheckPoint
|
||||||
*
|
*
|
||||||
|
* \param with_local whether the user calls CheckPoint with local model
|
||||||
|
*/
|
||||||
|
void AllreduceRobust::LocalModelCheck(bool with_local) {
|
||||||
|
if (use_local_model == -1) {
|
||||||
|
if (with_local) {
|
||||||
|
use_local_model = 1;
|
||||||
|
if (num_local_replica == 0) {
|
||||||
|
num_local_replica = default_local_replica;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
use_local_model = 0;
|
||||||
|
num_local_replica = 0;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
utils::Check(use_local_model == int(with_local),
|
||||||
|
"Can only call Checkpoint/LoadCheckPoint always with"\
|
||||||
|
"or without local_model, but not mixed case");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief internal implementation of checkpoint, support both lazy and normal way
|
||||||
|
*
|
||||||
* \param global_model pointer to the globally shared model/state
|
* \param global_model pointer to the globally shared model/state
|
||||||
* when calling this function, the caller need to gauranttees that global_model
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
* is the same in all nodes
|
* is the same in all nodes
|
||||||
* \param local_model pointer to local model, that is specific to current node/rank
|
* \param local_model pointer to local model, that is specific to current node/rank
|
||||||
* this can be NULL when no local state is needed
|
* this can be NULL when no local state is needed
|
||||||
|
* \param lazy_checkpt whether the action is lazy checkpoint
|
||||||
*
|
*
|
||||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
* \sa CheckPoint, LazyCheckPoint
|
||||||
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
|
||||||
* So only CheckPoint with global_model if possible
|
|
||||||
*
|
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
|
||||||
*/
|
*/
|
||||||
void AllreduceRobust::CheckPoint(const ISerializable *global_model,
|
void AllreduceRobust::CheckPoint_(const ISerializable *global_model,
|
||||||
const ISerializable *local_model) {
|
const ISerializable *local_model,
|
||||||
|
bool lazy_checkpt) {
|
||||||
// never do check point in single machine mode
|
// never do check point in single machine mode
|
||||||
if (world_size == 1) {
|
if (world_size == 1) {
|
||||||
version_number += 1; return;
|
version_number += 1; return;
|
||||||
}
|
}
|
||||||
if (local_model != NULL && num_local_replica == 0) {
|
this->LocalModelCheck(local_model != NULL);
|
||||||
num_local_replica = default_local_replica;
|
|
||||||
}
|
|
||||||
if (num_local_replica == 0) {
|
if (num_local_replica == 0) {
|
||||||
utils::Check(local_model == NULL,
|
utils::Check(local_model == NULL,
|
||||||
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
|
"need to set rabit_local_replica larger than 1 to checkpoint local_model");
|
||||||
@ -255,10 +275,15 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model,
|
|||||||
// increase version number
|
// increase version number
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
// save model
|
// save model
|
||||||
global_checkpoint.resize(0);
|
if (lazy_checkpt) {
|
||||||
utils::MemoryBufferStream fs(&global_checkpoint);
|
global_lazycheck = global_model;
|
||||||
fs.Write(&version_number, sizeof(version_number));
|
} else {
|
||||||
global_model->Save(fs);
|
global_checkpoint.resize(0);
|
||||||
|
utils::MemoryBufferStream fs(&global_checkpoint);
|
||||||
|
fs.Write(&version_number, sizeof(version_number));
|
||||||
|
global_model->Save(fs);
|
||||||
|
global_lazycheck = NULL;
|
||||||
|
}
|
||||||
// reset result buffer
|
// reset result buffer
|
||||||
resbuf.Clear(); seq_counter = 0;
|
resbuf.Clear(); seq_counter = 0;
|
||||||
// execute check ack step, load happens here
|
// execute check ack step, load happens here
|
||||||
@ -698,6 +723,14 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
|||||||
utils::Check(state == 1 || state == 2,
|
utils::Check(state == 1 || state == 2,
|
||||||
"LoadCheckPoint: too many nodes fails, cannot recover local state");
|
"LoadCheckPoint: too many nodes fails, cannot recover local state");
|
||||||
}
|
}
|
||||||
|
// do call save model if the checkpoint was lazy
|
||||||
|
if (role == kHaveData && global_lazycheck != NULL) {
|
||||||
|
global_checkpoint.resize(0);
|
||||||
|
utils::MemoryBufferStream fs(&global_checkpoint);
|
||||||
|
fs.Write(&version_number, sizeof(version_number));
|
||||||
|
global_lazycheck->Save(fs);
|
||||||
|
global_lazycheck = NULL;
|
||||||
|
}
|
||||||
// recover global checkpoint
|
// recover global checkpoint
|
||||||
size_t size = this->global_checkpoint.length();
|
size_t size = this->global_checkpoint.length();
|
||||||
int recv_link;
|
int recv_link;
|
||||||
|
|||||||
@ -99,7 +99,32 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const ISerializable *global_model,
|
virtual void CheckPoint(const ISerializable *global_model,
|
||||||
const ISerializable *local_model = NULL);
|
const ISerializable *local_model = NULL) {
|
||||||
|
this->CheckPoint_(global_model, local_model, false);
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief This function can be used to replace CheckPoint for global_model only,
|
||||||
|
* when certain conditions are met (see detailed explanation).
|
||||||
|
*
|
||||||
|
* This is a "lazy" checkpoint such that only the pointer to global_model is
|
||||||
|
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
|
||||||
|
* The global_model must remain unchanged until the last call of Allreduce/Broadcast in the current version finishes.
|
||||||
|
* In other words, global_model can be changed only between the last call of
|
||||||
|
* Allreduce/Broadcast and LazyCheckPoint in current version
|
||||||
|
*
|
||||||
|
* For example, suppose the calling sequence is:
|
||||||
|
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
|
||||||
|
*
|
||||||
|
* If the user only changes global_model in code3, then LazyCheckPoint can be used to
|
||||||
|
* improve efficiency of the program.
|
||||||
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller needs to guarantee that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \sa LoadCheckPoint, CheckPoint, VersionNumber
|
||||||
|
*/
|
||||||
|
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||||
|
this->CheckPoint_(global_model, NULL, true);
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||||
* call this function when IEngine throw an exception out,
|
* call this function when IEngine throw an exception out,
|
||||||
@ -274,10 +299,38 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
std::vector<uint64_t> data_;
|
std::vector<uint64_t> data_;
|
||||||
};
|
};
|
||||||
/*!
|
/*!
|
||||||
* \brief reset the all the existing links by sending Out-of-Band message marker
|
* \brief internal consistency check function,
|
||||||
* after this function finishes, all the messages received and sent before in all live links are discarded,
|
* use check to ensure user always call CheckPoint/LoadCheckPoint
|
||||||
* This allows us to get a fresh start after error has happened
|
* either with or without local_model, but never mixed; this function will set the appropriate settings
|
||||||
|
* in the first call of LoadCheckPoint/CheckPoint
|
||||||
*
|
*
|
||||||
|
* \param with_local whether the user calls CheckPoint with local model
|
||||||
|
*/
|
||||||
|
void LocalModelCheck(bool with_local);
|
||||||
|
/*!
|
||||||
|
* \brief internal implementation of checkpoint, support both lazy and normal way
|
||||||
|
*
|
||||||
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller needs to guarantee that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \param local_model pointer to local model, that is specific to current node/rank
|
||||||
|
* this can be NULL when no local state is needed
|
||||||
|
* \param lazy_checkpt whether the action is lazy checkpoint
|
||||||
|
*
|
||||||
|
* \sa CheckPoint, LazyCheckPoint
|
||||||
|
*/
|
||||||
|
void CheckPoint_(const ISerializable *global_model,
|
||||||
|
const ISerializable *local_model,
|
||||||
|
bool lazy_checkpt);
|
||||||
|
/*!
|
||||||
|
* \brief reset the all the existing links by sending Out-of-Band message marker
|
||||||
|
* after this function finishes, all the messages received and sent
|
||||||
|
* before in all live links are discarded,
|
||||||
|
* This allows us to get a fresh start after error has happened
|
||||||
|
*
|
||||||
|
* TODO(tqchen): this function is not yet functional and is not used by the engine,
|
||||||
|
* a simple reset-link and reconnect strategy is used instead
|
||||||
|
*
|
||||||
* \return this function can return kSuccess or kSockError
|
* \return this function can return kSuccess or kSockError
|
||||||
* when kSockError is returned, it simply means there are bad sockets in the links,
|
* when kSockError is returned, it simply means there are bad sockets in the links,
|
||||||
* and some link recovery proceduer is needed
|
* and some link recovery proceduer is needed
|
||||||
@ -468,10 +521,14 @@ o * the input state must exactly one saved state(local state of current node)
|
|||||||
ResultBuffer resbuf;
|
ResultBuffer resbuf;
|
||||||
// last check point global model
|
// last check point global model
|
||||||
std::string global_checkpoint;
|
std::string global_checkpoint;
|
||||||
|
// lazy checkpoint of global model
|
||||||
|
const ISerializable *global_lazycheck;
|
||||||
// number of replica for local state/model
|
// number of replica for local state/model
|
||||||
int num_local_replica;
|
int num_local_replica;
|
||||||
// number of default local replica
|
// number of default local replica
|
||||||
int default_local_replica;
|
int default_local_replica;
|
||||||
|
// flag to decide whether local model is used, -1: unknown, 0: no, 1:yes
|
||||||
|
int use_local_model;
|
||||||
// number of replica for global state/model
|
// number of replica for global state/model
|
||||||
int num_global_replica;
|
int num_global_replica;
|
||||||
// --- recovery data structure for local checkpoint
|
// --- recovery data structure for local checkpoint
|
||||||
|
|||||||
@ -42,6 +42,9 @@ class EmptyEngine : public IEngine {
|
|||||||
const ISerializable *local_model = NULL) {
|
const ISerializable *local_model = NULL) {
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
}
|
}
|
||||||
|
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||||
|
version_number += 1;
|
||||||
|
}
|
||||||
virtual int VersionNumber(void) const {
|
virtual int VersionNumber(void) const {
|
||||||
return version_number;
|
return version_number;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -45,6 +45,9 @@ class MPIEngine : public IEngine {
|
|||||||
const ISerializable *local_model = NULL) {
|
const ISerializable *local_model = NULL) {
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
}
|
}
|
||||||
|
virtual void LazyCheckPoint(const ISerializable *global_model) {
|
||||||
|
version_number += 1;
|
||||||
|
}
|
||||||
virtual int VersionNumber(void) const {
|
virtual int VersionNumber(void) const {
|
||||||
return version_number;
|
return version_number;
|
||||||
}
|
}
|
||||||
|
|||||||
12
test/test.mk
12
test/test.mk
@ -1,13 +1,8 @@
|
|||||||
ifndef $(nslave)
|
|
||||||
nslave=2
|
|
||||||
endif
|
|
||||||
ifndef $(ndata)
|
|
||||||
ndata=10
|
|
||||||
endif
|
|
||||||
|
|
||||||
# this is a makefile used to show testcases of rabit
|
# this is a makefile used to show testcases of rabit
|
||||||
.PHONY: model_recover local_recover speed
|
.PHONY:
|
||||||
|
|
||||||
|
test:
|
||||||
|
../tracker/rabit_mpi.py -v 1 -n 10 bash keepalive.sh test_model_recover 1 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=8,1,2,0
|
||||||
|
|
||||||
# this experiment test recovery with actually process exit, use keepalive to keep program alive
|
# this experiment test recovery with actually process exit, use keepalive to keep program alive
|
||||||
model_recover_10_10k:
|
model_recover_10_10k:
|
||||||
@ -18,3 +13,4 @@ model_recover_10_10k_die_same:
|
|||||||
|
|
||||||
model_recover_10_10k_die_hard:
|
model_recover_10_10k_die_hard:
|
||||||
../tracker/rabit_demo.py -n 10 test_model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0
|
../tracker/rabit_demo.py -n 10 test_model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user