change in interface, seems resetlink is still bad

This commit is contained in:
tqchen 2014-12-01 21:39:51 -08:00
parent b76cd5858c
commit 255218a2f3
8 changed files with 128 additions and 40 deletions

View File

@ -61,15 +61,35 @@ class IEngine {
/*! /*!
* \brief load latest check point * \brief load latest check point
* \param p_model pointer to the model * \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful * \return the version number of check point loaded
* false if there was no stored checkpoint, means we are start over gain * if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/ */
virtual bool LoadCheckPoint(utils::ISerializable *p_model) = 0; virtual int LoadCheckPoint(utils::ISerializable *p_model) = 0;
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model * \param p_model pointer to the model
* \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const utils::ISerializable &model) = 0; virtual void CheckPoint(const utils::ISerializable &model) = 0;
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const = 0;
/*! \brief get rank of current node */ /*! \brief get rank of current node */
virtual int GetRank(void) const = 0; virtual int GetRank(void) const = 0;
/*! \brief get total number of */ /*! \brief get total number of */

View File

@ -21,6 +21,7 @@ AllReduceBase::AllReduceBase(void) {
nport_trial = 1000; nport_trial = 1000;
rank = 0; rank = 0;
world_size = 1; world_size = 1;
version_number = 0;
this->SetParam("reduce_buffer", "256MB"); this->SetParam("reduce_buffer", "256MB");
} }

View File

@ -35,10 +35,10 @@ class AllReduceBase : public IEngine {
// constant one byte out of band message to indicate error happening // constant one byte out of band message to indicate error happening
AllReduceBase(void); AllReduceBase(void);
virtual ~AllReduceBase(void) {} virtual ~AllReduceBase(void) {}
// shutdown the engine
void Shutdown(void);
// initialize the manager // initialize the manager
void Init(void); void Init(void);
// shutdown the engine
virtual void Shutdown(void);
/*! /*!
* \brief set parameters to the engine * \brief set parameters to the engine
* \param name parameter name * \param name parameter name
@ -82,20 +82,34 @@ class AllReduceBase : public IEngine {
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
"AllReduce failed"); "AllReduce failed");
} }
/*! /*!
* \brief load latest check point * \brief load latest check point
* \param p_model pointer to the model * \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful * \return the version number of check point loaded
* false if there was no stored checkpoint, means we are start over gain * if returned version == 0, this means no model has been CheckPointed
*/ * the p_model is not touched, user should do necessary initialization by themselves
virtual bool LoadCheckPoint(utils::ISerializable *p_model) { * \sa CheckPoint, VersionNumber
return false; */
virtual int LoadCheckPoint(utils::ISerializable *p_model) {
return 0;
} }
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model * \param p_model pointer to the model
* \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const utils::ISerializable &model) { virtual void CheckPoint(const utils::ISerializable &model) {
version_number += 1;
}
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const {
return version_number;
} }
/*! /*!
* \brief explicitly re-init everything before calling LoadCheckPoint * \brief explicitly re-init everything before calling LoadCheckPoint
@ -236,6 +250,8 @@ class AllReduceBase : public IEngine {
* \sa ReturnType * \sa ReturnType
*/ */
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root); ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
//---- data structure related to model ----
int version_number;
//---- local data related to link ---- //---- local data related to link ----
// index of parent link, can be -1, meaning this is root of the tree // index of parent link, can be -1, meaning this is root of the tree
int parent_index; int parent_index;

View File

@ -16,9 +16,27 @@
namespace rabit { namespace rabit {
namespace engine { namespace engine {
AllReduceRobust::AllReduceRobust(void) { AllReduceRobust::AllReduceRobust(void) {
result_buffer_round = 2; result_buffer_round = 1;
seq_counter = 0; seq_counter = 0;
} }
/*! \brief shutdown the engine */
void AllReduceRobust::Shutdown(void) {
// need to sync the exec before we shutdown, do a pesudo check point
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
"check point must return true");
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
"check ack must return true");
AllReduceBase::Shutdown();
}
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
void AllReduceRobust::SetParam(const char *name, const char *val) { void AllReduceRobust::SetParam(const char *name, const char *val) {
AllReduceBase::SetParam(name, val); AllReduceBase::SetParam(name, val);
if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val); if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val);
@ -91,24 +109,25 @@ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
/*! /*!
* \brief load latest check point * \brief load latest check point
* \param p_model pointer to the model * \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful * \return the version number of check point loaded
* false if there was no stored checkpoint, means we are start over gain * if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
* \sa CheckPoint, VersionNumber
*/ */
bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) { int AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
// check if we succesfll // check if we succesfll
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) { if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
// reset result buffer // reset result buffer
resbuf.Clear(); seq_counter = 0; resbuf.Clear(); seq_counter = 0;
// if loaded model is empty, this simply means we did not call checkpoint yet
// ask caller to reinit model
if (checked_model.length() == 0) return false;
// load from buffer // load from buffer
utils::MemoryBufferStream fs(&checked_model); utils::MemoryBufferStream fs(&checked_model);
fs.Read(&version_number, sizeof(version_number));
if (version_number == 0) return version_number;
p_model->Load(fs); p_model->Load(fs);
// run another phase of check ack, if recovered from data // run another phase of check ack, if recovered from data
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
"check ack must return true"); "check ack must return true");
return true; return version_number;
} else { } else {
// reset result buffer // reset result buffer
resbuf.Clear(); seq_counter = 0; resbuf.Clear(); seq_counter = 0;
@ -118,14 +137,19 @@ bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
} }
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model * \param p_model pointer to the model
* \sa LoadCheckPoint, VersionNumber
*/ */
void AllReduceRobust::CheckPoint(const utils::ISerializable &model) { void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
// increase version number
version_number += 1;
// save model // save model
checked_model.resize(0); checked_model.resize(0);
utils::MemoryBufferStream fs(&checked_model); utils::MemoryBufferStream fs(&checked_model);
fs.Write(&version_number, sizeof(version_number));
model.Save(fs); model.Save(fs);
utils::Check(checked_model.length() != 0, "CheckPoint: empty model, model.Save must save something");
// execute checkpoint, note: when checkpoint existing, load will not happen // execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
"check point must return true"); "check point must return true");
@ -586,7 +610,7 @@ bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations"); utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations");
} }
// request // request
ActionSummary req(flag, seqno); ActionSummary req(flag, seqno);
while (true) { while (true) {
// action // action
ActionSummary act = req; ActionSummary act = req;

View File

@ -20,6 +20,8 @@ class AllReduceRobust : public AllReduceBase {
public: public:
AllReduceRobust(void); AllReduceRobust(void);
virtual ~AllReduceRobust(void) {} virtual ~AllReduceRobust(void) {}
/*! \brief shutdown the engine */
virtual void Shutdown(void);
/*! /*!
* \brief set parameters to the engine * \brief set parameters to the engine
* \param name parameter name * \param name parameter name
@ -45,18 +47,23 @@ class AllReduceRobust : public AllReduceBase {
* \param root the root worker id to broadcast the data * \param root the root worker id to broadcast the data
*/ */
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root); virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root);
/*! /*!
* \brief load latest check point * \brief load latest check point
* \param p_model pointer to the model * \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful * \return the version number of check point loaded
* false if there was no stored checkpoint, means we are start over gain * if returned version == 0, this means no model has been CheckPointed
*/ * the p_model is not touched, user should do necessary initialization by themselves
virtual bool LoadCheckPoint(utils::ISerializable *p_model); * \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(utils::ISerializable *p_model);
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model * \param p_model pointer to the model
* \sa LoadCheckPoint, VersionNumber
*/ */
virtual void CheckPoint(const utils::ISerializable &model); virtual void CheckPoint(const utils::ISerializable &model);
/*! /*!
* \brief explicitly re-init everything before calling LoadCheckPoint * \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out, * call this function when IEngine throw an exception out,
@ -359,8 +366,7 @@ class AllReduceRobust : public AllReduceBase {
// result buffer // result buffer
ResultBuffer resbuf; ResultBuffer resbuf;
// last check point model // last check point model
std::string checked_model; std::string checked_model;
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit

View File

@ -93,21 +93,43 @@ template<typename OP, typename DType>
inline void AllReduce(DType *sendrecvbuf, size_t count) { inline void AllReduce(DType *sendrecvbuf, size_t count) {
engine::GetEngine()->AllReduce(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>); engine::GetEngine()->AllReduce(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>);
} }
/*! /*!
* \brief load latest check point * \brief load latest check point
* \param p_model pointer to the model * \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful * \return the version number of check point loaded
* false if there was no stored checkpoint, means we are start over gain * if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/ */
inline bool LoadCheckPoint(utils::ISerializable *p_model) { inline int LoadCheckPoint(utils::ISerializable *p_model) {
return engine::GetEngine()->LoadCheckPoint(p_model); return engine::GetEngine()->LoadCheckPoint(p_model);
} }
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model * \param p_model pointer to the model
* \sa LoadCheckPoint, VersionNumber
*/ */
inline void CheckPoint(const utils::ISerializable &model) { inline void CheckPoint(const utils::ISerializable &model) {
engine::GetEngine()->CheckPoint(model); engine::GetEngine()->CheckPoint(model);
} }
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
inline int VersionNumber(void) {
return engine::GetEngine()->VersionNumber();
}
} // namespace rabit } // namespace rabit
#endif // RABIT_ALLREDUCE_H #endif // RABIT_ALLREDUCE_H

View File

@ -11,7 +11,7 @@ else
endif endif
# specify tensor path # specify tensor path
BIN = test_allreduce test_recover BIN = test_allreduce test_recover test_model_recover
OBJ = engine_base.o engine_robust.o engine.o OBJ = engine_base.o engine_robust.o engine.o
.PHONY: clean all .PHONY: clean all
@ -23,6 +23,7 @@ engine.o: ../src/engine.cc ../src/*.h
engine_robust.o: ../src/engine_robust.cc ../src/*.h engine_robust.o: ../src/engine_robust.cc ../src/*.h
test_allreduce: test_allreduce.cpp ../src/*.h $(OBJ) test_allreduce: test_allreduce.cpp ../src/*.h $(OBJ)
test_recover: test_recover.cpp ../src/*.h $(OBJ) test_recover: test_recover.cpp ../src/*.h $(OBJ)
test_model_recover: test_model_recover.cpp ../src/*.h $(OBJ)
$(BIN) : $(BIN) :
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^) $(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^)

View File

@ -70,18 +70,16 @@ inline void TestBcast(test::Mock &mock, size_t n, int root, int ntrial) {
// dummy model // dummy model
class Model : public rabit::utils::ISerializable { class Model : public rabit::utils::ISerializable {
public: public:
// iterations
int iter;
// load from stream // load from stream
virtual void Load(rabit::utils::IStream &fi) { virtual void Load(rabit::utils::IStream &fi) {
fi.Read(&iter, sizeof(iter)); // do nothing
} }
/*! \brief save the model to the stream */ /*! \brief save the model to the stream */
virtual void Save(rabit::utils::IStream &fo) const { virtual void Save(rabit::utils::IStream &fo) const {
fo.Write(&iter, sizeof(iter)); // do nothing
} }
virtual void InitModel(void) { virtual void InitModel(void) {
iter = 0; // do nothing
} }
}; };
@ -101,7 +99,7 @@ int main(int argc, char *argv[]) {
int ntrial = 0; int ntrial = 0;
while (true) { while (true) {
try { try {
if (!rabit::LoadCheckPoint(&model)) { if (rabit::LoadCheckPoint(&model) == 0) {
model.InitModel(); model.InitModel();
} }
utils::LogPrintf("[%d/%d] start at %s\n", rank, ntrial, name.c_str()); utils::LogPrintf("[%d/%d] start at %s\n", rank, ntrial, name.c_str());