before make rabit public

This commit is contained in:
tqchen 2014-12-04 17:30:58 -08:00
parent cc410b8c90
commit 821eb21ae2
3 changed files with 14 additions and 14 deletions

View File

@ -2,7 +2,7 @@
rabit is a light weight library that provides a fault tolerant interface of Allreduce and Broadcast. It is designed to support easy implementation of distributed machine learning programs, many of which sits naturally under Allreduce abstraction. rabit is a light weight library that provides a fault tolerant interface of Allreduce and Broadcast. It is designed to support easy implementation of distributed machine learning programs, many of which sits naturally under Allreduce abstraction.
Contributors: https://github.com/tqchen/rabit/graphs/contributors Interface: [rabit.h](src/rabit.h)
Features Features
==== ====
@ -27,4 +27,3 @@ Design Goal
* rabit should run fast * rabit should run fast
* rabit is light weight * rabit is light weight
* rabit dig safe burrows to avoid disasters * rabit dig safe burrows to avoid disasters

View File

@ -137,7 +137,7 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
// reset result buffer // reset result buffer
resbuf.Clear(); seq_counter = 0; resbuf.Clear(); seq_counter = 0;
// load from buffer // load from buffer
utils::MemoryBufferStream fs(&mglobal_model); utils::MemoryBufferStream fs(&global_checkpoint);
fs.Read(&version_number, sizeof(version_number)); fs.Read(&version_number, sizeof(version_number));
if (version_number == 0) return version_number; if (version_number == 0) return version_number;
global_model->Load(fs); global_model->Load(fs);
@ -155,7 +155,7 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
/*! /*!
* \brief checkpoint the model, meaning we finished a stage of execution * \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one * every time we call check point, there is a version number which will increase by one
* *
* \param global_model pointer to the globally shared model/state * \param global_model pointer to the globally shared model/state
* when calling this function, the caller need to gauranttees that global_model * when calling this function, the caller need to gauranttees that global_model
* is the same in all nodes * is the same in all nodes
@ -174,11 +174,12 @@ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
// execute checkpoint, note: when checkpoint existing, load will not happen // execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq), utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
"check point must return true"); "check point must return true");
// this is the critical region where we will change all the stored models
// increase version number // increase version number
version_number += 1; version_number += 1;
// save model // save model
mglobal_model.resize(0); global_checkpoint.resize(0);
utils::MemoryBufferStream fs(&mglobal_model); utils::MemoryBufferStream fs(&global_checkpoint);
fs.Write(&version_number, sizeof(version_number)); fs.Write(&version_number, sizeof(version_number));
global_model->Save(fs); global_model->Save(fs);
// reset result buffer // reset result buffer
@ -580,16 +581,16 @@ AllreduceRobust::TryRecoverData(RecoverType role,
*/ */
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) { AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
RecoverType role = requester ? kRequestData : kHaveData; RecoverType role = requester ? kRequestData : kHaveData;
size_t size = this->mglobal_model.length(); size_t size = this->global_checkpoint.length();
int recv_link; int recv_link;
std::vector<bool> req_in; std::vector<bool> req_in;
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in); ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
if (succ != kSuccess) return succ; if (succ != kSuccess) return succ;
if (role == kRequestData) { if (role == kRequestData) {
mglobal_model.resize(size); global_checkpoint.resize(size);
} }
if (size == 0) return kSuccess; if (size == 0) return kSuccess;
return TryRecoverData(role, &mglobal_model[0], size, recv_link, req_in); return TryRecoverData(role, &global_checkpoint[0], size, recv_link, req_in);
} }
/*! /*!
* \brief try to get the result of operation specified by seqno * \brief try to get the result of operation specified by seqno

View File

@ -349,7 +349,7 @@ class AllreduceRobust : public AllreduceBase {
void *sendrecvbuf_, void *sendrecvbuf_,
size_t size, size_t size,
int recv_link, int recv_link,
const std::vector<bool> &req_in); const std::vector<bool> &req_in);
/*! /*!
* \brief perform a ring passing to receive data from prev link, and sent data to next link * \brief perform a ring passing to receive data from prev link, and sent data to next link
* this allows data to stream over a ring structure * this allows data to stream over a ring structure
@ -410,7 +410,7 @@ class AllreduceRobust : public AllreduceBase {
// result buffer of all reduce // result buffer of all reduce
ResultBuffer resbuf; ResultBuffer resbuf;
// last check point global model // last check point global model
std::string mglobal_model; std::string global_checkpoint;
// number of replica for local state/model // number of replica for local state/model
int num_local_replica; int num_local_replica;
// pointer to memory position in the local model // pointer to memory position in the local model
@ -419,9 +419,9 @@ class AllreduceRobust : public AllreduceBase {
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring // local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
std::vector<size_t> local_rptr; std::vector<size_t> local_rptr;
// storage for local model replicas // storage for local model replicas
std::string mlocal_model; std::string local_checkpoint;
// temporal storage // temporal storage for doing local checkpointing
std::string tmp_local_model; std::string tmp_local_check;
}; };
} // namespace engine } // namespace engine
} // namespace rabit } // namespace rabit