diff --git a/include/rabit.h b/include/rabit.h index 17ef5e616..eb1f0a07f 100644 --- a/include/rabit.h +++ b/include/rabit.h @@ -203,6 +203,27 @@ inline int LoadCheckPoint(ISerializable *global_model, */ inline void CheckPoint(const ISerializable *global_model, const ISerializable *local_model = NULL); +/*! + * \brief This function can be used to replace CheckPoint for global_model only, + * when certain conditions are met (see detailed explanation). + * + * This is a "lazy" checkpoint such that only the pointer to global_model is + * remembered and no memory copy is taken. To use this function, the user MUST ensure that: + * The global_model must remain unchanged until the last call of Allreduce/Broadcast in current version finishes. + * In other words, global_model can be changed only between the last call of + * Allreduce/Broadcast and LazyCheckPoint in current version + * + * For example, suppose the calling sequence is: + * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint + * + * If the user only changes global_model in code3, then LazyCheckPoint can be used to + * improve efficiency of the program. + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \sa LoadCheckPoint, CheckPoint, VersionNumber + */ +inline void LazyCheckPoint(const ISerializable *global_model); /*! * \return version number of current stored model, * which means how many calls to CheckPoint we made so far diff --git a/include/rabit/engine.h b/include/rabit/engine.h index c06fbc6cc..fbbdaa8f0 100644 --- a/include/rabit/engine.h +++ b/include/rabit/engine.h @@ -114,6 +114,27 @@ class IEngine { */ virtual void CheckPoint(const ISerializable *global_model, const ISerializable *local_model = NULL) = 0; + /*! + * \brief This function can be used to replace CheckPoint for global_model only, + * when certain conditions are met (see detailed explanation).
+ * + * This is a "lazy" checkpoint such that only the pointer to global_model is + * remembered and no memory copy is taken. To use this function, the user MUST ensure that: + * The global_model must remain unchanged until the last call of Allreduce/Broadcast in current version finishes. + * In other words, global_model can be changed only between the last call of + * Allreduce/Broadcast and LazyCheckPoint in current version + * + * For example, suppose the calling sequence is: + * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint + * + * If the user only changes global_model in code3, then LazyCheckPoint can be used to + * improve efficiency of the program. + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \sa LoadCheckPoint, CheckPoint, VersionNumber + */ + virtual void LazyCheckPoint(const ISerializable *global_model) = 0; /*! * \return version number of current stored model, * which means how many calls to CheckPoint we made so far diff --git a/include/rabit/rabit-inl.h b/include/rabit/rabit-inl.h index 3ba3ec95e..4ee1a42b5 100644 --- a/include/rabit/rabit-inl.h +++ b/include/rabit/rabit-inl.h @@ -183,6 +183,10 @@ inline void CheckPoint(const ISerializable *global_model, const ISerializable *local_model) { engine::GetEngine()->CheckPoint(global_model, local_model); } +// lazy checkpoint the model, only remember the pointer to global_model +inline void LazyCheckPoint(const ISerializable *global_model) { + engine::GetEngine()->LazyCheckPoint(global_model); +} // return the version number of currently stored model inline int VersionNumber(void) { return engine::GetEngine()->VersionNumber(); diff --git a/src/allreduce_base.h b/src/allreduce_base.h index da57c34f6..aaee59312 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -146,6 +146,29 @@ class AllreduceBase : public IEngine { const ISerializable
*local_model = NULL) { version_number += 1; } + /*! + * \brief This function can be used to replace CheckPoint for global_model only, + * when certain conditions are met (see detailed explanation). + * + * This is a "lazy" checkpoint such that only the pointer to global_model is + * remembered and no memory copy is taken. To use this function, the user MUST ensure that: + * The global_model must remain unchanged until the last call of Allreduce/Broadcast in current version finishes. + * In other words, global_model can be changed only between the last call of + * Allreduce/Broadcast and LazyCheckPoint in current version + * + * For example, suppose the calling sequence is: + * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint + * + * If the user only changes global_model in code3, then LazyCheckPoint can be used to + * improve efficiency of the program. + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \sa LoadCheckPoint, CheckPoint, VersionNumber + */ + virtual void LazyCheckPoint(const ISerializable *global_model) { + version_number += 1; + } /*!
* \return version number of current stored model, * which means how many calls to CheckPoint we made so far diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 7cb2b3611..90a0f4fac 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -25,6 +25,8 @@ AllreduceRobust::AllreduceRobust(void) { seq_counter = 0; local_chkpt_version = 0; result_buffer_round = 1; + global_lazycheck = NULL; + use_local_model = -1; } void AllreduceRobust::Init(void) { AllreduceBase::Init(); @@ -154,9 +156,7 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model, ISerializable *local_model) { // skip action in single node if (world_size == 1) return 0; - if (local_model != NULL && num_local_replica == 0) { - num_local_replica = default_local_replica; - } + this->LocalModelCheck(local_model != NULL); if (num_local_replica == 0) { utils::Check(local_model == NULL, "need to set rabit_local_replica larger than 1 to checkpoint local_model"); @@ -199,30 +199,50 @@ int AllreduceRobust::LoadCheckPoint(ISerializable *global_model, } } /*! 
- * \brief checkpoint the model, meaning we finished a stage of execution - * every time we call check point, there is a version number which will increase by one + * \brief internal consistency check function, + * checks that the user always calls CheckPoint/LoadCheckPoint + * with or without local but not both, this function will set the appropriate settings + * in the first call of LoadCheckPoint/CheckPoint * + * \param with_local whether the user calls CheckPoint with local model + */ +void AllreduceRobust::LocalModelCheck(bool with_local) { + if (use_local_model == -1) { + if (with_local) { + use_local_model = 1; + if (num_local_replica == 0) { + num_local_replica = default_local_replica; + } + } else { + use_local_model = 0; + num_local_replica = 0; + } + } else { + utils::Check(use_local_model == int(with_local), + "Can only call Checkpoint/LoadCheckPoint always with "\ + "or without local_model, but not mixed case"); + } +} +/*! + * \brief internal implementation of checkpoint, support both lazy and normal way + * * \param global_model pointer to the globally shared model/state * when calling this function, the caller need to gauranttees that global_model * is the same in all nodes * \param local_model pointer to local model, that is specific to current node/rank * this can be NULL when no local state is needed + * \param lazy_checkpt whether the action is lazy checkpoint * - * NOTE: local_model requires explicit replication of the model for fault-tolerance, which will - * bring replication cost in CheckPoint function. global_model do not need explicit replication.
- * So only CheckPoint with global_model if possible - * - * \sa LoadCheckPoint, VersionNumber + * \sa CheckPoint, LazyCheckPoint */ -void AllreduceRobust::CheckPoint(const ISerializable *global_model, - const ISerializable *local_model) { +void AllreduceRobust::CheckPoint_(const ISerializable *global_model, + const ISerializable *local_model, + bool lazy_checkpt) { // never do check point in single machine mode if (world_size == 1) { version_number += 1; return; } - if (local_model != NULL && num_local_replica == 0) { - num_local_replica = default_local_replica; - } + this->LocalModelCheck(local_model != NULL); if (num_local_replica == 0) { utils::Check(local_model == NULL, "need to set rabit_local_replica larger than 1 to checkpoint local_model"); @@ -255,10 +275,15 @@ void AllreduceRobust::CheckPoint(const ISerializable *global_model, // increase version number version_number += 1; // save model - global_checkpoint.resize(0); - utils::MemoryBufferStream fs(&global_checkpoint); - fs.Write(&version_number, sizeof(version_number)); - global_model->Save(fs); + if (lazy_checkpt) { + global_lazycheck = global_model; + } else { + global_checkpoint.resize(0); + utils::MemoryBufferStream fs(&global_checkpoint); + fs.Write(&version_number, sizeof(version_number)); + global_model->Save(fs); + global_lazycheck = NULL; + } // reset result buffer resbuf.Clear(); seq_counter = 0; // execute check ack step, load happens here @@ -698,6 +723,14 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { utils::Check(state == 1 || state == 2, "LoadCheckPoint: too many nodes fails, cannot recover local state"); } + // do call save model if the checkpoint was lazy + if (role == kHaveData && global_lazycheck != NULL) { + global_checkpoint.resize(0); + utils::MemoryBufferStream fs(&global_checkpoint); + fs.Write(&version_number, sizeof(version_number)); + global_lazycheck->Save(fs); + global_lazycheck = NULL; + } // recover global checkpoint size_t size = 
this->global_checkpoint.length(); int recv_link; diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index 921f18319..078ff1598 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -99,7 +99,32 @@ class AllreduceRobust : public AllreduceBase { * \sa LoadCheckPoint, VersionNumber */ virtual void CheckPoint(const ISerializable *global_model, - const ISerializable *local_model = NULL); + const ISerializable *local_model = NULL) { + this->CheckPoint_(global_model, local_model, false); + } + /*! + * \brief This function can be used to replace CheckPoint for global_model only, + * when certain conditions are met (see detailed explanation). + * + * This is a "lazy" checkpoint such that only the pointer to global_model is + * remembered and no memory copy is taken. To use this function, the user MUST ensure that: + * The global_model must remain unchanged until the last call of Allreduce/Broadcast in current version finishes. + * In other words, global_model can be changed only between the last call of + * Allreduce/Broadcast and LazyCheckPoint in current version + * + * For example, suppose the calling sequence is: + * LazyCheckPoint, code1, Allreduce, code2, Broadcast, code3, LazyCheckPoint + * + * If the user only changes global_model in code3, then LazyCheckPoint can be used to + * improve efficiency of the program. + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \sa LoadCheckPoint, CheckPoint, VersionNumber + */ + virtual void LazyCheckPoint(const ISerializable *global_model) { + this->CheckPoint_(global_model, NULL, true); + } /*! * \brief explicitly re-init everything before calling LoadCheckPoint * call this function when IEngine throw an exception out, @@ -274,10 +299,38 @@ class AllreduceRobust : public AllreduceBase { std::vector data_; }; /*!
- * \brief reset the all the existing links by sending Out-of-Band message marker - * after this function finishes, all the messages received and sent before in all live links are discarded, - * This allows us to get a fresh start after error has happened + * \brief internal consistency check function, + * checks that the user always calls CheckPoint/LoadCheckPoint + * with or without local but not both, this function will set the appropriate settings + * in the first call of LoadCheckPoint/CheckPoint * + * \param with_local whether the user calls CheckPoint with local model + */ + void LocalModelCheck(bool with_local); + /*! + * \brief internal implementation of checkpoint, support both lazy and normal way + * + * \param global_model pointer to the globally shared model/state + * when calling this function, the caller needs to guarantee that global_model + * is the same in all nodes + * \param local_model pointer to local model, that is specific to current node/rank + * this can be NULL when no local state is needed + * \param lazy_checkpt whether the action is lazy checkpoint + * + * \sa CheckPoint, LazyCheckPoint + */ + void CheckPoint_(const ISerializable *global_model, + const ISerializable *local_model, + bool lazy_checkpt); + /*!
+ * \brief reset the all the existing links by sending Out-of-Band message marker + * after this function finishes, all the messages received and sent + * before in all live links are discarded, + * This allows us to get a fresh start after error has happened + * + * TODO(tqchen): this function is not yet functioning was not used by engine, + * simple resetlink and reconnect strategy is used + * * \return this function can return kSuccess or kSockError * when kSockError is returned, it simply means there are bad sockets in the links, * and some link recovery proceduer is needed @@ -468,10 +521,14 @@ o * the input state must exactly one saved state(local state of current node) ResultBuffer resbuf; // last check point global model std::string global_checkpoint; + // lazy checkpoint of global model + const ISerializable *global_lazycheck; // number of replica for local state/model int num_local_replica; // number of default local replica int default_local_replica; + // flag to decide whether local model is used, -1: unknown, 0: no, 1:yes + int use_local_model; // number of replica for global state/model int num_global_replica; // --- recovery data structure for local checkpoint diff --git a/src/engine_empty.cc b/src/engine_empty.cc index 0c7020914..3d14e3ef3 100644 --- a/src/engine_empty.cc +++ b/src/engine_empty.cc @@ -42,6 +42,9 @@ class EmptyEngine : public IEngine { const ISerializable *local_model = NULL) { version_number += 1; } + virtual void LazyCheckPoint(const ISerializable *global_model) { + version_number += 1; + } virtual int VersionNumber(void) const { return version_number; } diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index c1b723572..9c6206ebf 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -45,6 +45,9 @@ class MPIEngine : public IEngine { const ISerializable *local_model = NULL) { version_number += 1; } + virtual void LazyCheckPoint(const ISerializable *global_model) { + version_number += 1; + } virtual int VersionNumber(void) const 
{ return version_number; } diff --git a/test/test.mk b/test/test.mk index efc5b418c..05085bdfc 100644 --- a/test/test.mk +++ b/test/test.mk @@ -1,13 +1,8 @@ -ifndef $(nslave) - nslave=2 -endif -ifndef $(ndata) - ndata=10 -endif - # this is a makefile used to show testcases of rabit -.PHONY: model_recover local_recover speed +.PHONY: test +test: + ../tracker/rabit_mpi.py -v 1 -n 10 bash keepalive.sh test_model_recover 1 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=8,1,2,0 # this experiment test recovery with actually process exit, use keepalive to keep program alive model_recover_10_10k: @@ -18,3 +13,4 @@ model_recover_10_10k_die_same: model_recover_10_10k_die_hard: ../tracker/rabit_demo.py -n 10 test_model_recover 10000 mock=0,0,1,0 mock=1,1,1,0 mock=1,1,1,1 mock=0,1,1,0 mock=4,1,1,0 mock=9,1,1,0 mock=8,1,2,0 mock=4,1,3,0 +