add local model in checkpoint interface, a new goal

This commit is contained in:
tqchen 2014-12-04 11:09:15 -08:00
parent 79e7862583
commit cc410b8c90
7 changed files with 248 additions and 36 deletions

View File

@ -348,7 +348,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
}
if (len != -1) {
size_down_in += static_cast<size_t>(len);
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error");
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error");
} else {
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
}

View File

@ -84,23 +84,48 @@ class AllreduceBase : public IEngine {
}
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller needs to guarantee that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local model is needed
*
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* global_model is not touched, the user should do the necessary initialization themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(utils::ISerializable *p_model) {
virtual int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model = NULL) {
return 0;
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller needs to guarantee that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in the CheckPoint function. global_model does not need explicit replication,
* so CheckPoint with only global_model whenever possible
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable &model) {
virtual void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model = NULL) {
version_number += 1;
}
/*!
@ -267,6 +292,8 @@ class AllreduceBase : public IEngine {
int parent_rank;
// sockets of all links
std::vector<LinkRecord> links;
// pointer to someplace in the ring
LinkRecord *ring_prev, *ring_next;
//----- meta information-----
// unique identifier of the possible job this process is doing
// used to assign ranks, optional, default to NULL

View File

@ -17,6 +17,7 @@ namespace rabit {
namespace engine {
// Construct the robust engine with its bookkeeping fields at their starting values:
//  - seq_counter starts from zero
//  - result_buffer_round starts at 1
//  - num_local_replica (number of replicas for the local state/model) starts at 2
AllreduceRobust::AllreduceRobust(void)
    : seq_counter(0),
      result_buffer_round(1),
      num_local_replica(2) {
}
/*! \brief shutdown the engine */
@ -108,22 +109,38 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
}
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller needs to guarantee that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local model is needed
*
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* global_model is not touched, the user should do the necessary initialization themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model) {
utils::Check(local_model == NULL, "CheckPoint local_model is not yet supported");
// check if we succeeded
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// load from buffer
utils::MemoryBufferStream fs(&checked_model);
utils::MemoryBufferStream fs(&mglobal_model);
fs.Read(&version_number, sizeof(version_number));
if (version_number == 0) return version_number;
p_model->Load(fs);
global_model->Load(fs);
// run another phase of check ack, if recovered from data
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
"check ack must return true");
@ -139,20 +156,31 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller needs to guarantee that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in the CheckPoint function. global_model does not need explicit replication,
* so CheckPoint with only global_model whenever possible
*
* \sa LoadCheckPoint, VersionNumber
*/
void AllreduceRobust::CheckPoint(const utils::ISerializable &model) {
// increase version number
version_number += 1;
// save model
checked_model.resize(0);
utils::MemoryBufferStream fs(&checked_model);
fs.Write(&version_number, sizeof(version_number));
model.Save(fs);
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model) {
utils::Assert(local_model == NULL, "CheckPoint local model is not supported yet");
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
"check point must return true");
// increase version number
version_number += 1;
// save model
mglobal_model.resize(0);
utils::MemoryBufferStream fs(&mglobal_model);
fs.Write(&version_number, sizeof(version_number));
global_model->Save(fs);
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
@ -488,6 +516,10 @@ AllreduceRobust::TryRecoverData(RecoverType role,
}
if (finished) break;
selecter.Select();
// exception handling
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
}
if (role == kRequestData) {
const int pid = recv_link;
if (selecter.CheckRead(links[pid].sock)) {
@ -548,16 +580,16 @@ AllreduceRobust::TryRecoverData(RecoverType role,
*/
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
RecoverType role = requester ? kRequestData : kHaveData;
size_t size = this->checked_model.length();
size_t size = this->mglobal_model.length();
int recv_link;
std::vector<bool> req_in;
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
if (succ != kSuccess) return succ;
if (role == kRequestData) {
checked_model.resize(size);
mglobal_model.resize(size);
}
if (size == 0) return kSuccess;
return TryRecoverData(role, &checked_model[0], size, recv_link, req_in);
return TryRecoverData(role, &mglobal_model[0], size, recv_link, req_in);
}
/*!
* \brief try to get the result of operation specified by seqno
@ -674,6 +706,81 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
utils::Assert(false, "RecoverExec: should not reach here");
return true;
}
/*!
 * \brief perform a ring passing to receive data from the prev link, and send data to the next link
 *  this allows data to stream over a ring structure
 *  sendrecvbuf[0:read_ptr] are already provided by current node
 *  current node will recv sendrecvbuf[read_ptr:read_end] from prev link
 *  current node will send sendrecvbuf[write_ptr:write_end] to next link
 *  write_ptr will wait till the data has been read before sending the data,
 *  i.e. only bytes below read_ptr are ever forwarded
 *  this function requires read_end >= write_end
 *
 * \param sendrecvbuf_ the place to hold the incoming and outgoing data
 * \param read_ptr the initial read pointer
 * \param read_end the ending position to read
 * \param write_ptr the initial write pointer
 * \param write_end the ending position to write
 * \param prev_link pointer to link to previous position in ring
 * \param next_link pointer to link of next position in ring
 * \return kSuccess when there is nothing to do or both the read and write ranges are completed,
 *         kGetExcept when an exception is flagged on either socket,
 *         kSockError when the prev link is closed by the peer or a send/recv fails
 *         with an errno other than EAGAIN/EWOULDBLOCK
 */
AllreduceRobust::ReturnType
AllreduceRobust::RingPassing(void *sendrecvbuf_,
                             size_t read_ptr,
                             size_t read_end,
                             size_t write_ptr,
                             size_t write_end,
                             LinkRecord *prev_link,
                             LinkRecord *next_link) {
  if (links.size() == 0 || read_end == 0) return kSuccess;
  // both links are dereferenced below, so they must be valid
  utils::Assert(prev_link != NULL && next_link != NULL,
                "RingPassing: prev_link and next_link must not be NULL");
  // we can only forward data that has been received (or was already present),
  // so the write range must not extend beyond the read range
  utils::Assert(write_end <= read_end, "boundary check");
  utils::Assert(read_ptr <= read_end, "boundary check");
  utils::Assert(write_ptr <= write_end, "boundary check");
  // take reference
  LinkRecord &prev = *prev_link, &next = *next_link;
  // send recv buffer
  char *buf = reinterpret_cast<char*>(sendrecvbuf_);
  while (true) {
    bool finished = true;
    utils::SelectHelper selecter;
    if (read_ptr != read_end) {
      selecter.WatchRead(prev.sock);
      finished = false;
    }
    // only watch for write when there is received data that is not yet forwarded
    if (write_ptr < read_ptr && write_ptr != write_end) {
      selecter.WatchWrite(next.sock);
      finished = false;
    }
    selecter.WatchException(prev.sock);
    selecter.WatchException(next.sock);
    if (finished) break;
    selecter.Select();
    if (selecter.CheckExcept(prev.sock)) return kGetExcept;
    if (selecter.CheckExcept(next.sock)) return kGetExcept;
    if (read_ptr != read_end && selecter.CheckRead(prev.sock)) {
      ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
      // a zero-length read means the peer closed the connection
      if (len == 0) {
        prev.sock.Close(); return kSockError;
      }
      if (len != -1) {
        read_ptr += static_cast<size_t>(len);
      } else {
        // would-block is not an error, simply retry in the next round
        if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
      }
    }
    if (write_ptr != write_end && write_ptr < read_ptr &&
        selecter.CheckWrite(next.sock)) {
      // never send past write_end, and never send bytes that have not been read yet
      size_t nsend = std::min(write_end - write_ptr, read_ptr - write_ptr);
      ssize_t len = next.sock.Send(buf + write_ptr, nsend);
      if (len != -1) {
        write_ptr += static_cast<size_t>(len);
      } else {
        if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
      }
    }
  }
  return kSuccess;
}
} // namespace engine
} // namespace rabit

View File

@ -49,21 +49,46 @@ class AllreduceRobust : public AllreduceBase {
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root);
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller needs to guarantee that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local model is needed
*
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* global_model is not touched, the user should do the necessary initialization themselves
*
* Common usage example:
* int iter = rabit::LoadCheckPoint(&model);
* if (iter == 0) model.InitParameters();
* for (i = iter; i < max_iter; ++i) {
* do many things, include allreduce
* rabit::CheckPoint(model);
* }
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(utils::ISerializable *p_model);
virtual int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model = NULL);
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller needs to guarantee that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in the CheckPoint function. global_model does not need explicit replication,
* so CheckPoint with only global_model whenever possible
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable &model);
virtual void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model = NULL);
/*!
* \brief explicitly re-init everything before calling LoadCheckPoint
* call this function when IEngine throw an exception out,
@ -259,7 +284,7 @@ class AllreduceRobust : public AllreduceBase {
* result by recovering procedure, the action is complete, no further action is needed
* - false means this is the latest action that has not yet been executed, need to execute the action
*/
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kMaxSeq);
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kMaxSeq);
/*!
* \brief try to load check point
*
@ -325,6 +350,30 @@ class AllreduceRobust : public AllreduceBase {
size_t size,
int recv_link,
const std::vector<bool> &req_in);
/*!
* \brief perform a ring passing to receive data from prev link, and sent data to next link
* this allows data to stream over a ring structure
* sendrecvbuf[0:read_ptr] are already provided by current node
* current node will recv sendrecvbuf[read_ptr:read_end] from prev link
* current node will send sendrecvbuf[write_ptr:write_end] to next link
* write_ptr will wait till the data has been read before sending the data
* this function requires read_end >= write_end
*
* \param sendrecvbuf_ the place to hold the incoming and outgoing data
* \param read_ptr the initial read pointer
* \param read_end the ending position to read
* \param write_ptr the initial write pointer
* \param write_end the ending position to write
* \param prev_link pointer to link to previous position in ring
* \param next_link pointer to link of next position in ring
*/
ReturnType RingPassing(void *senrecvbuf_,
size_t read_ptr,
size_t read_end,
size_t write_ptr,
size_t write_end,
LinkRecord *prev_link,
LinkRecord *next_link);
/*!
* \brief run message passing algorithm on the allreduce tree
* the result is edge message stored in p_edge_in and p_edge_out
@ -358,10 +407,21 @@ class AllreduceRobust : public AllreduceBase {
int seq_counter;
// the round of result buffer, used to mode the result
int result_buffer_round;
// result buffer
// result buffer of all reduce
ResultBuffer resbuf;
// last check point model
std::string checked_model;
// last check point global model
std::string mglobal_model;
// number of replica for local state/model
int num_local_replica;
// pointer to memory position in the local model
// local model is stored in CSR format (like a sparse matrix)
// local_model[rptr[0]:rptr[1]] stores the model of current node
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
std::vector<size_t> local_rptr;
// storage for local model replicas
std::string mlocal_model;
// temporary storage
std::string tmp_local_model;
};
} // namespace engine
} // namespace rabit

View File

@ -60,7 +60,12 @@ class IEngine {
virtual void InitAfterException(void) = 0;
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller needs to guarantee that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local model is needed
*
* \return the version number of check point loaded
* if returned version == 0, this means no model has been CheckPointed
* global_model is not touched, the user should do the necessary initialization themselves
@ -75,15 +80,26 @@ class IEngine {
*
* \sa CheckPoint, VersionNumber
*/
virtual int LoadCheckPoint(utils::ISerializable *p_model) = 0;
virtual int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model = NULL) = 0;
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* every time we call check point, there is a version number which will increase by one
*
* \param p_model pointer to the model
* \param global_model pointer to the globally shared model/state
* when calling this function, the caller needs to guarantee that global_model
* is the same in all nodes
* \param local_model pointer to local model, that is specific to current node/rank
* this can be NULL when no local state is needed
*
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
* bring replication cost in the CheckPoint function. global_model does not need explicit replication,
* so CheckPoint with only global_model whenever possible
*
* \sa LoadCheckPoint, VersionNumber
*/
virtual void CheckPoint(const utils::ISerializable &model) = 0;
virtual void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model = NULL) = 0;
/*!
* \return version number of current stored model,
* which means how many calls to CheckPoint we made so far

View File

@ -32,10 +32,12 @@ class MPIEngine : public IEngine {
virtual void InitAfterException(void) {
utils::Error("MPI is not fault tolerant");
}
virtual int LoadCheckPoint(utils::ISerializable *p_model) {
virtual int LoadCheckPoint(utils::ISerializable *global_model,
utils::ISerializable *local_model = NULL) {
return 0;
}
virtual void CheckPoint(const utils::ISerializable &model) {
virtual void CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model = NULL) {
version_number += 1;
}
virtual int VersionNumber(void) const {

View File

@ -129,7 +129,7 @@ inline int LoadCheckPoint(utils::ISerializable *p_model) {
}
// checkpoint the model, meaning we finished a stage of execution
inline void CheckPoint(const utils::ISerializable &model) {
engine::GetEngine()->CheckPoint(model);
engine::GetEngine()->CheckPoint(&model);
}
// return the version number of currently stored model
inline int VersionNumber(void) {