change in interface, seems resetlink is still bad

This commit is contained in:
tqchen 2014-12-01 21:39:51 -08:00
parent b76cd5858c
commit 255218a2f3
8 changed files with 128 additions and 40 deletions

View File

@ -61,15 +61,35 @@ class IEngine {
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful
* false if there was no stored checkpoint, means we are start over gain
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
virtual bool LoadCheckPoint(utils::ISerializable *p_model) = 0;
virtual int LoadCheckPoint(utils::ISerializable *p_model) = 0;
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable &model) = 0;
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const = 0;
/*! \brief get rank of current node */
virtual int GetRank(void) const = 0;
/*! \brief get total number of */

View File

@ -21,6 +21,7 @@ AllReduceBase::AllReduceBase(void) {
nport_trial = 1000;
rank = 0;
world_size = 1;
version_number = 0;
this->SetParam("reduce_buffer", "256MB");
}

View File

@ -35,10 +35,10 @@ class AllReduceBase : public IEngine {
// constant one byte out of band message to indicate error happening
AllReduceBase(void);
virtual ~AllReduceBase(void) {}
// shutdown the engine
void Shutdown(void);
// initialize the manager
void Init(void);
// shutdown the engine
virtual void Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name
@ -85,17 +85,31 @@ class AllReduceBase : public IEngine {
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful
* false if there was no stored checkpoint, means we are start over gain
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
* \sa CheckPoint, VersionNumber
*/
virtual bool LoadCheckPoint(utils::ISerializable *p_model) {
return false;
virtual int LoadCheckPoint(utils::ISerializable *p_model) {
return 0;
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable &model) {
version_number += 1;
}
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
virtual int VersionNumber(void) const {
return version_number;
}
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
@ -236,6 +250,8 @@ class AllReduceBase : public IEngine {
* \sa ReturnType
*/
ReturnType TryBroadcast(void *sendrecvbuf_, size_t size, int root);
//---- data structure related to model ----
int version_number;
//---- local data related to link ----
// index of parent link, can be -1, meaning this is root of the tree
int parent_index;

View File

@ -16,9 +16,27 @@
namespace rabit {
namespace engine {
AllReduceRobust::AllReduceRobust(void) {
result_buffer_round = 2;
result_buffer_round = 1;
seq_counter = 0;
}
/*! \brief shutdown the engine */
void AllReduceRobust::Shutdown(void) {
// need to sync the exec before we shutdown, do a pesudo check point
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
"check point must return true");
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
"check ack must return true");
AllReduceBase::Shutdown();
}
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
void AllReduceRobust::SetParam(const char *name, const char *val) {
AllReduceBase::SetParam(name, val);
if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val);
@ -91,24 +109,25 @@ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful
* false if there was no stored checkpoint, means we are start over gain
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
* \sa CheckPoint, VersionNumber
*/
bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
int AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
// check if we succesfll
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// if loaded model is empty, this simply means we did not call checkpoint yet
// ask caller to reinit model
if (checked_model.length() == 0) return false;
// load from buffer
utils::MemoryBufferStream fs(&checked_model);
fs.Read(&version_number, sizeof(version_number));
if (version_number == 0) return version_number;
p_model->Load(fs);
// run another phase of check ack, if recovered from data
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
"check ack must return true");
return true;
return version_number;
} else {
// reset result buffer
resbuf.Clear(); seq_counter = 0;
@ -118,14 +137,19 @@ bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model
* \sa LoadCheckPoint, VersionNumber
*/
void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
// increase version number
version_number += 1;
// save model
checked_model.resize(0);
utils::MemoryBufferStream fs(&checked_model);
fs.Write(&version_number, sizeof(version_number));
model.Save(fs);
utils::Check(checked_model.length() != 0, "CheckPoint: empty model, model.Save must save something");
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
"check point must return true");

View File

@ -20,6 +20,8 @@ class AllReduceRobust : public AllReduceBase {
public:
AllReduceRobust(void);
virtual ~AllReduceRobust(void) {}
/*! \brief shutdown the engine */
virtual void Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name
@ -48,13 +50,18 @@ class AllReduceRobust : public AllReduceBase {
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful
* false if there was no stored checkpoint, means we are start over gain
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
* \sa CheckPoint, VersionNumber
*/
virtual bool LoadCheckPoint(utils::ISerializable *p_model);
virtual int LoadCheckPoint(utils::ISerializable *p_model);
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable &model);
/*!
@ -360,7 +367,6 @@ class AllReduceRobust : public AllReduceBase {
ResultBuffer resbuf;
// last check point model
std::string checked_model;
};
} // namespace engine
} // namespace rabit

View File

@ -96,18 +96,40 @@ inline void AllReduce(DType *sendrecvbuf, size_t count) {
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was stored checkpoint and load was successful
* false if there was no stored checkpoint, means we are start over gain
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* the p_model is not touched, user should do necessary initialization by themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
inline bool LoadCheckPoint(utils::ISerializable *p_model) {
inline int LoadCheckPoint(utils::ISerializable *p_model) {
return engine::GetEngine()->LoadCheckPoint(p_model);
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model
* \sa LoadCheckPoint, VersionNumber
*/
inline void CheckPoint(const utils::ISerializable &model) {
engine::GetEngine()->CheckPoint(model);
}
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far
* \sa LoadCheckPoint, CheckPoint
*/
inline int VersionNumber(void) {
return engine::GetEngine()->VersionNumber();
}
} // namespace rabit
#endif // RABIT_ALLREDUCE_H

View File

@ -11,7 +11,7 @@ else
endif
# specify tensor path
BIN = test_allreduce test_recover
BIN = test_allreduce test_recover test_model_recover
OBJ = engine_base.o engine_robust.o engine.o
.PHONY: clean all
@ -23,6 +23,7 @@ engine.o: ../src/engine.cc ../src/*.h
engine_robust.o: ../src/engine_robust.cc ../src/*.h
test_allreduce: test_allreduce.cpp ../src/*.h $(OBJ)
test_recover: test_recover.cpp ../src/*.h $(OBJ)
test_model_recover: test_model_recover.cpp ../src/*.h $(OBJ)
$(BIN) :
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc, $^)

View File

@ -70,18 +70,16 @@ inline void TestBcast(test::Mock &mock, size_t n, int root, int ntrial) {
// dummy model
class Model : public rabit::utils::ISerializable {
public:
// iterations
int iter;
// load from stream
virtual void Load(rabit::utils::IStream &fi) {
fi.Read(&iter, sizeof(iter));
// do nothing
}
/*! \brief save the model to the stream */
virtual void Save(rabit::utils::IStream &fo) const {
fo.Write(&iter, sizeof(iter));
// do nothing
}
virtual void InitModel(void) {
iter = 0;
// do nothing
}
};
@ -101,7 +99,7 @@ int main(int argc, char *argv[]) {
int ntrial = 0;
while (true) {
try {
if (!rabit::LoadCheckPoint(&model)) {
if (rabit::LoadCheckPoint(&model) == 0) {
model.InitModel();
}
utils::LogPrintf("[%d/%d] start at %s\n", rank, ntrial, name.c_str());