diff --git a/src/engine.h b/src/engine.h index 807f7c6ad..1c040a9e4 100644 --- a/src/engine.h +++ b/src/engine.h @@ -61,15 +61,35 @@ class IEngine { /*! * \brief load latest check point * \param p_model pointer to the model - * \return true if there was stored checkpoint and load was successful - * false if there was no stored checkpoint, means we are start over gain + * \return the version number of check point loaded + * if returned version == 0, this means no model has been CheckPointed + * the p_model is not touched, user should do necessary initialization by themselves + * + * Common usage example: + * int iter = rabit::LoadCheckPoint(&model); + * if (iter == 0) model.InitParameters(); + * for (int i = iter; i < max_iter; ++i) { + * do many things, including allreduce + * rabit::CheckPoint(model); + * } + * + * \sa CheckPoint, VersionNumber */ - virtual bool LoadCheckPoint(utils::ISerializable *p_model) = 0; + virtual int LoadCheckPoint(utils::ISerializable *p_model) = 0; /*! * \brief checkpoint the model, meaning we finished a stage of execution + * every time we call check point, there is a version number which will increase by one + * * \param p_model pointer to the model + * \sa LoadCheckPoint, VersionNumber */ virtual void CheckPoint(const utils::ISerializable &model) = 0; + /*! + * \return version number of current stored model, + * which means how many calls to CheckPoint we made so far + * \sa LoadCheckPoint, CheckPoint + */ + virtual int VersionNumber(void) const = 0; /*! \brief get rank of current node */ virtual int GetRank(void) const = 0; /*! 
\brief get total number of */ diff --git a/src/engine_base.cc b/src/engine_base.cc index 3b08d1502..556b71e08 100644 --- a/src/engine_base.cc +++ b/src/engine_base.cc @@ -21,6 +21,7 @@ AllReduceBase::AllReduceBase(void) { nport_trial = 1000; rank = 0; world_size = 1; + version_number = 0; this->SetParam("reduce_buffer", "256MB"); } diff --git a/src/engine_base.h b/src/engine_base.h index 9e533fe27..48d38aeb9 100644 --- a/src/engine_base.h +++ b/src/engine_base.h @@ -35,10 +35,10 @@ class AllReduceBase : public IEngine { // constant one byte out of band message to indicate error happening AllReduceBase(void); virtual ~AllReduceBase(void) {} - // shutdown the engine - void Shutdown(void); // initialize the manager void Init(void); + // shutdown the engine; virtual so AllReduceRobust can sync the workers before calling this + virtual void Shutdown(void); /*! * \brief set parameters to the engine * \param name parameter name @@ -82,20 +82,34 @@ class AllReduceBase : public IEngine { utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, "AllReduce failed"); } - /*! + /*! * \brief load latest check point * \param p_model pointer to the model - * \return true if there was stored checkpoint and load was successful - * false if there was no stored checkpoint, means we are start over gain - */ - virtual bool LoadCheckPoint(utils::ISerializable *p_model) { - return false; + * \return the version number of check point loaded + * if returned version == 0, this means no model has been CheckPointed + * the p_model is not touched, user should do necessary initialization by themselves + * \sa CheckPoint, VersionNumber + */ + virtual int LoadCheckPoint(utils::ISerializable *p_model) { + return 0; } /*! 
* \brief checkpoint the model, meaning we finished a stage of execution + * every time we call check point, there is a version number which will increase by one + * * \param p_model pointer to the model + * \sa LoadCheckPoint, VersionNumber */ virtual void CheckPoint(const utils::ISerializable &model) { + version_number += 1; + } + /*! + * \return version number of current stored model, + * which means how many calls to CheckPoint we made so far + * \sa LoadCheckPoint, CheckPoint + */ + virtual int VersionNumber(void) const { + return version_number; } /*! * \brief explicitly re-init everything before calling LoadCheckPoint @@ -236,6 +250,8 @@ class AllReduceBase : public IEngine { * \sa ReturnType */ ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root); + //---- data structure related to model ---- + int version_number; //---- local data related to link ---- // index of parent link, can be -1, meaning this is root of the tree int parent_index; diff --git a/src/engine_robust.cc b/src/engine_robust.cc index ab33b0f0a..59a5b79a3 100644 --- a/src/engine_robust.cc +++ b/src/engine_robust.cc @@ -16,9 +16,27 @@ namespace rabit { namespace engine { AllReduceRobust::AllReduceRobust(void) { - result_buffer_round = 2; + result_buffer_round = 1; seq_counter = 0; } +/*! \brief shutdown the engine */ +void AllReduceRobust::Shutdown(void) { + // need to sync the exec before we shutdown, do a pseudo check point + // execute checkpoint, note: when checkpoint existing, load will not happen + utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq), + "check point must return true"); + // reset result buffer + resbuf.Clear(); seq_counter = 0; + // execute check ack step, load happens here + utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq), + "check ack must return true"); + AllReduceBase::Shutdown(); +} +/*! 
+ * \brief set parameters to the engine + * \param name parameter name + * \param val parameter value + */ void AllReduceRobust::SetParam(const char *name, const char *val) { AllReduceBase::SetParam(name, val); if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val); @@ -91,24 +109,25 @@ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) /*! * \brief load latest check point * \param p_model pointer to the model - * \return true if there was stored checkpoint and load was successful - * false if there was no stored checkpoint, means we are start over gain + * \return the version number of check point loaded + * if returned version == 0, this means no model has been CheckPointed + * the p_model is not touched, user should do necessary initialization by themselves + * \sa CheckPoint, VersionNumber */ -bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) { +int AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) { // check if we succesfll if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) { // reset result buffer resbuf.Clear(); seq_counter = 0; - // if loaded model is empty, this simply means we did not call checkpoint yet - // ask caller to reinit model - if (checked_model.length() == 0) return false; // load from buffer utils::MemoryBufferStream fs(&checked_model); + fs.Read(&version_number, sizeof(version_number)); + if (version_number == 0) return version_number; p_model->Load(fs); // run another phase of check ack, if recovered from data utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq), "check ack must return true"); - return true; + return version_number; } else { // reset result buffer resbuf.Clear(); seq_counter = 0; @@ -118,14 +137,19 @@ bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) { } /*! 
* \brief checkpoint the model, meaning we finished a stage of execution + * every time we call check point, there is a version number which will increase by one + * * \param p_model pointer to the model + * \sa LoadCheckPoint, VersionNumber */ void AllReduceRobust::CheckPoint(const utils::ISerializable &model) { + // increase version number first, so the value written below identifies this checkpoint + version_number += 1; // save model checked_model.resize(0); utils::MemoryBufferStream fs(&checked_model); + fs.Write(&version_number, sizeof(version_number)); model.Save(fs); - utils::Check(checked_model.length() != 0, "CheckPoint: empty model, model.Save must save something"); // execute checkpoint, note: when checkpoint existing, load will not happen utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq), "check point must return true"); @@ -586,7 +610,7 @@ bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) { utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations"); } // request - ActionSummary req(flag, seqno); + ActionSummary req(flag, seqno); while (true) { // action ActionSummary act = req; diff --git a/src/engine_robust.h b/src/engine_robust.h index 7116764d8..32aee1f2b 100644 --- a/src/engine_robust.h +++ b/src/engine_robust.h @@ -20,6 +20,8 @@ class AllReduceRobust : public AllReduceBase { public: AllReduceRobust(void); virtual ~AllReduceRobust(void) {} + /*! \brief shutdown the engine */ + virtual void Shutdown(void); /*! * \brief set parameters to the engine * \param name parameter name @@ -45,18 +47,23 @@ class AllReduceRobust : public AllReduceBase { * \param root the root worker id to broadcast the data */ virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root); - /*! + /*! 
* \brief load latest check point * \param p_model pointer to the model - * \return true if there was stored checkpoint and load was successful - * false if there was no stored checkpoint, means we are start over gain - */ - virtual bool LoadCheckPoint(utils::ISerializable *p_model); + * \return the version number of check point loaded + * if returned version == 0, this means no model has been CheckPointed + * the p_model is not touched, user should do necessary initialization by themselves + * \sa CheckPoint, VersionNumber + */ + virtual int LoadCheckPoint(utils::ISerializable *p_model); /*! * \brief checkpoint the model, meaning we finished a stage of execution + * every time we call check point, there is a version number which will increase by one + * * \param p_model pointer to the model + * \sa LoadCheckPoint, VersionNumber */ - virtual void CheckPoint(const utils::ISerializable &model); + virtual void CheckPoint(const utils::ISerializable &model); /*! * \brief explicitly re-init everything before calling LoadCheckPoint * call this function when IEngine throw an exception out, @@ -359,8 +366,7 @@ class AllReduceRobust : public AllReduceBase { // result buffer ResultBuffer resbuf; // last check point model - std::string checked_model; - + std::string checked_model; }; } // namespace engine } // namespace rabit diff --git a/src/rabit.h b/src/rabit.h index 635e3ff87..5659798ec 100644 --- a/src/rabit.h +++ b/src/rabit.h @@ -93,21 +93,43 @@ template inline void AllReduce(DType *sendrecvbuf, size_t count) { engine::GetEngine()->AllReduce(sendrecvbuf, sizeof(DType), count, op::Reducer); } -/*! +/*! 
* \brief load latest check point * \param p_model pointer to the model - * \return true if there was stored checkpoint and load was successful - * false if there was no stored checkpoint, means we are start over gain + * \return the version number of check point loaded + * if returned version == 0, this means no model has been CheckPointed + * the p_model is not touched, user should do necessary initialization by themselves + * + * Common usage example: + * int iter = rabit::LoadCheckPoint(&model); + * if (iter == 0) model.InitParameters(); + * for (int i = iter; i < max_iter; ++i) { + * do many things, including allreduce + * rabit::CheckPoint(model); + * } + * + * \sa CheckPoint, VersionNumber */ -inline bool LoadCheckPoint(utils::ISerializable *p_model) { +inline int LoadCheckPoint(utils::ISerializable *p_model) { return engine::GetEngine()->LoadCheckPoint(p_model); } /*! * \brief checkpoint the model, meaning we finished a stage of execution + * every time we call check point, there is a version number which will increase by one + * * \param p_model pointer to the model + * \sa LoadCheckPoint, VersionNumber */ inline void CheckPoint(const utils::ISerializable &model) { engine::GetEngine()->CheckPoint(model); } +/*! 
+ * \return version number of current stored model, + * which means how many calls to CheckPoint we made so far + * \sa LoadCheckPoint, CheckPoint + */ +inline int VersionNumber(void) { + return engine::GetEngine()->VersionNumber(); +} } // namespace rabit #endif // RABIT_ALLREDUCE_H diff --git a/test/Makefile b/test/Makefile index a3f6b07c7..a48fcd77c 100644 --- a/test/Makefile +++ b/test/Makefile @@ -11,7 +11,7 @@ else endif # specify tensor path -BIN = test_allreduce test_recover +BIN = test_allreduce test_recover test_model_recover OBJ = engine_base.o engine_robust.o engine.o .PHONY: clean all @@ -23,6 +23,7 @@ engine.o: ../src/engine.cc ../src/*.h engine_robust.o: ../src/engine_robust.cc ../src/*.h test_allreduce: test_allreduce.cpp ../src/*.h $(OBJ) test_recover: test_recover.cpp ../src/*.h $(OBJ) +test_model_recover: test_model_recover.cpp ../src/*.h $(OBJ) $(BIN) : $(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) diff --git a/test/test_recover.cpp b/test/test_recover.cpp index 215177f20..761226889 100644 --- a/test/test_recover.cpp +++ b/test/test_recover.cpp @@ -70,18 +70,16 @@ inline void TestBcast(test::Mock &mock, size_t n, int root, int ntrial) { // dummy model class Model : public rabit::utils::ISerializable { public: - // iterations - int iter; // load from stream virtual void Load(rabit::utils::IStream &fi) { - fi.Read(&iter, sizeof(iter)); + // nothing to load: the model is stateless now that iter is gone (LoadCheckPoint returns the version number) } /*! \brief save the model to the stream */ virtual void Save(rabit::utils::IStream &fo) const { - fo.Write(&iter, sizeof(iter)); + // nothing to save: the dummy model has no state left } virtual void InitModel(void) { - iter = 0; + // nothing to initialize: the dummy model has no state left } }; @@ -101,7 +99,7 @@ int main(int argc, char *argv[]) { int ntrial = 0; while (true) { try { - if (!rabit::LoadCheckPoint(&model)) { + if (rabit::LoadCheckPoint(&model) == 0) { model.InitModel(); } utils::LogPrintf("[%d/%d] start at %s\n", rank, ntrial, name.c_str());