add lazy check, need test, find a race condition

This commit is contained in:
tqchen
2015-01-14 11:58:43 -08:00
parent bddfa2fc24
commit 87c7817124
9 changed files with 192 additions and 31 deletions

View File

@@ -114,6 +114,27 @@ class IEngine {
*/
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model = NULL) = 0;
/*!
* \brief This function can be used to replace CheckPoint for global_model only,
* when certain condition is met(see detailed expplaination.
*
* This is a "lazy" checkpoint such that only the pointer to global_model is
* remembered and no memory copy is taken. To use this function, the user MUST ensure that:
* The global_model must remain unchanged util last call of Allreduce/Broadcast in current version finishs.
* In another words, global_model model can be changed only between last call of
* Allreduce/Broadcast and LazyCheckPoint in current version
*
* For example, suppose the calling sequence is:
* LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint
*
* If user can only changes global_model in code3, then LazyCheckPoint can be used to
* improve efficiency of the program.
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes
* \sa LoadCheckPoint, CheckPoint, VersionNumber
*/
virtual void LazyCheckPoint(const ISerializable *global_model) = 0;
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far

View File

@@ -183,6 +183,10 @@ inline void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
engine::GetEngine()->CheckPoint(global_model, local_model);
}
// lazy checkpoint the model, only remember the pointer to global_model
inline void LazyCheckPoint(const ISerializable *global_model) {
engine::GetEngine()->LazyCheckPoint(global_model);
}
// return the version number of currently stored model
inline int VersionNumber(void) {
return engine::GetEngine()->VersionNumber();