change in interface, seems resetlink is still bad

This commit is contained in:
tqchen
2014-12-01 21:39:51 -08:00
parent b76cd5858c
commit 255218a2f3
8 changed files with 128 additions and 40 deletions

View File

@@ -61,15 +61,35 @@ class IEngine {
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful
* false if there was no stored checkpoint, means we are start over gain
* \return the version number of the check point loaded;
* if the returned version == 0, no model has been CheckPointed yet and
* p_model is not touched, so the user should do the necessary initialization themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
virtual bool LoadCheckPoint(utils::ISerializable *p_model) = 0;
virtual int LoadCheckPoint(utils::ISerializable *p_model) = 0;
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param model the model to be checkpointed
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable &model) = 0;
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const = 0;
/*! \brief get rank of current node */
virtual int GetRank(void) const = 0;
/*! \brief get total number of */

View File

@@ -21,6 +21,7 @@ AllReduceBase::AllReduceBase(void) {
nport_trial = 1000;
rank = 0;
world_size = 1;
version_number = 0;
this->SetParam("reduce_buffer", "256MB");
}

View File

@@ -35,10 +35,10 @@ class AllReduceBase : public IEngine {
// constant one byte out of band message to indicate error happening
AllReduceBase(void);
virtual ~AllReduceBase(void) {}
// shutdown the engine
void Shutdown(void);
// initialize the manager
void Init(void);
// shutdown the engine
virtual void Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name
@@ -82,20 +82,34 @@ class AllReduceBase : public IEngine {
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
"AllReduce failed");
}
/*!
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful
* false if there was no stored checkpoint, means we are start over gain
*/
virtual bool LoadCheckPoint(utils::ISerializable *p_model) {
return false;
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(utils::ISerializable *p_model) {
return 0;
}
/*!
 * \brief mark the end of one stage of execution;
 *  each call advances the check point version number by one
 *
 * \param model the model to be checkpointed; this implementation does not read it
 * \sa LoadCheckPoint, VersionNumber
 */
virtual void CheckPoint(const utils::ISerializable &model) {
  // only the version counter is advanced here
  ++version_number;
}
/*!
* \brief get the version number of the check point
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const {
// version_number is the counter advanced by CheckPoint
return version_number;
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
@@ -236,6 +250,8 @@ class AllReduceBase : public IEngine {
* \sa ReturnType
*/
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
//---- data structure related to model ----
int version_number;
//---- local data related to link ----
// index of parent link, can be -1, meaning this is root of the tree
int parent_index;

View File

@@ -16,9 +16,27 @@
namespace rabit {
namespace engine {
AllReduceRobust::AllReduceRobust(void) {
result_buffer_round = 2;
result_buffer_round = 1;
seq_counter = 0;
}
/*! \brief shutdown the engine */
void AllReduceRobust::Shutdown(void) {
// Sequence: a check point step, a reset of the result buffer and sequence
// counter, a check ack step, and finally the base class shutdown.
// No data buffer is passed to RecoverExec in either step (NULL, size 0).
// need to sync the exec before we shutdown, do a pseudo check point
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
"check point must return true");
// reset result buffer and restart the sequence counter from zero
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
"check ack must return true");
// hand over to the base class shutdown once both steps have returned true
AllReduceBase::Shutdown();
}
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
void AllReduceRobust::SetParam(const char *name, const char *val) {
AllReduceBase::SetParam(name, val);
if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val);
@@ -91,24 +109,25 @@ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful
* false if there was no stored checkpoint, means we are start over gain
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
* \sa CheckPoint, VersionNumber
*/
bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
int AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
// check whether the load-check recovery step succeeded
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// if loaded model is empty, this simply means we did not call checkpoint yet
// ask caller to reinit model
if (checked_model.length() == 0) return false;
// load from buffer
utils::MemoryBufferStream fs(&checked_model);
fs.Read(&version_number, sizeof(version_number));
if (version_number == 0) return version_number;
p_model->Load(fs);
// run another phase of check ack, if recovered from data
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
"check ack must return true");
return true;
return version_number;
} else {
// reset result buffer
resbuf.Clear(); seq_counter = 0;
@@ -118,14 +137,19 @@ bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param model the model to be checkpointed
* \sa LoadCheckPoint, VersionNumber
*/
void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
// increase version number
version_number += 1;
// save model
checked_model.resize(0);
utils::MemoryBufferStream fs(&checked_model);
fs.Write(&version_number, sizeof(version_number));
model.Save(fs);
utils::Check(checked_model.length() != 0, "CheckPoint: empty model, model.Save must save something");
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
"check point must return true");
@@ -586,7 +610,7 @@ bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations");
}
// request
ActionSummary req(flag, seqno);
ActionSummary req(flag, seqno);
while (true) {
// action
ActionSummary act = req;

View File

@@ -20,6 +20,8 @@ class AllReduceRobust : public AllReduceBase {
public:
AllReduceRobust(void);
virtual ~AllReduceRobust(void) {}
/*! \brief shutdown the engine */
virtual void Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name
@@ -45,18 +47,23 @@ class AllReduceRobust : public AllReduceBase {
* \param root the root worker id to broadcast the data
*/
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root);
/*!
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful
* false if there was no stored checkpoint, means we are start over gain
*/
virtual bool LoadCheckPoint(utils::ISerializable *p_model);
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(utils::ISerializable *p_model);
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param model the model to be checkpointed
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable &model);
virtual void CheckPoint(const utils::ISerializable &model);
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
@@ -359,8 +366,7 @@ class AllReduceRobust : public AllReduceBase {
// result buffer
ResultBuffer resbuf;
// last check point model
std::string checked_model;
std::string checked_model;
};
} // namespace engine
} // namespace rabit

View File

@@ -93,21 +93,43 @@ template<typename OP, typename DType>
inline void AllReduce(DType *sendrecvbuf, size_t count) {
engine::GetEngine()->AllReduce(sendrecvbuf, sizeof(DType), count, op::Reducer<OP,DType>);
}
/*!
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful
* false if there was no stored checkpoint, means we are start over gain
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
inline bool LoadCheckPoint(utils::ISerializable *p_model) {
inline int LoadCheckPoint(utils::ISerializable *p_model) {
return engine::GetEngine()->LoadCheckPoint(p_model);
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param model the model to be checkpointed
* \sa LoadCheckPoint, VersionNumber
*/
inline void CheckPoint(const utils::ISerializable &model) {
// forward the call to the engine returned by engine::GetEngine()
engine::GetEngine()->CheckPoint(model);
}
/*!
* \brief get the version number of the check point
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
inline int VersionNumber(void) {
// forward the query to the engine returned by engine::GetEngine()
return engine::GetEngine()->VersionNumber();
}
} // namespace rabit
#endif // RABIT_ALLREDUCE_H