add local model in checkpoint interface, a new goal
This commit is contained in:
parent
79e7862583
commit
cc410b8c90
@ -348,7 +348,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
|||||||
}
|
}
|
||||||
if (len != -1) {
|
if (len != -1) {
|
||||||
size_down_in += static_cast<size_t>(len);
|
size_down_in += static_cast<size_t>(len);
|
||||||
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error");
|
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error");
|
||||||
} else {
|
} else {
|
||||||
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -84,23 +84,48 @@ class AllreduceBase : public IEngine {
|
|||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief load latest check point
|
* \brief load latest check point
|
||||||
* \param p_model pointer to the model
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \param local_model pointer to local model, that is specific to current node/rank
|
||||||
|
* this can be NULL when no local model is needed
|
||||||
|
*
|
||||||
* \return the version number of check point loaded
|
* \return the version number of check point loaded
|
||||||
* if returned version == 0, this means no model has been CheckPointed
|
* if returned version == 0, this means no model has been CheckPointed
|
||||||
* the p_model is not touched, user should do necessary initialization by themselves
|
* the p_model is not touched, user should do necessary initialization by themselves
|
||||||
|
*
|
||||||
|
* Common usage example:
|
||||||
|
* int iter = rabit::LoadCheckPoint(&model);
|
||||||
|
* if (iter == 0) model.InitParameters();
|
||||||
|
* for (i = iter; i < max_iter; ++i) {
|
||||||
|
* do many things, include allreduce
|
||||||
|
* rabit::CheckPoint(model);
|
||||||
|
* }
|
||||||
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual int LoadCheckPoint(utils::ISerializable *p_model) {
|
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||||
|
utils::ISerializable *local_model = NULL) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
*
|
*
|
||||||
* \param p_model pointer to the model
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \param local_model pointer to local model, that is specific to current node/rank
|
||||||
|
* this can be NULL when no local state is needed
|
||||||
|
*
|
||||||
|
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||||
|
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||||
|
* So only CheckPoint with global_model if possible
|
||||||
|
*
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const utils::ISerializable &model) {
|
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||||
|
const utils::ISerializable *local_model = NULL) {
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -267,6 +292,8 @@ class AllreduceBase : public IEngine {
|
|||||||
int parent_rank;
|
int parent_rank;
|
||||||
// sockets of all links
|
// sockets of all links
|
||||||
std::vector<LinkRecord> links;
|
std::vector<LinkRecord> links;
|
||||||
|
// pointer to someplace in the ring
|
||||||
|
LinkRecord *ring_prev, *ring_next;
|
||||||
//----- meta information-----
|
//----- meta information-----
|
||||||
// unique identifier of the possible job this process is doing
|
// unique identifier of the possible job this process is doing
|
||||||
// used to assign ranks, optional, default to NULL
|
// used to assign ranks, optional, default to NULL
|
||||||
|
|||||||
@ -17,6 +17,7 @@ namespace rabit {
|
|||||||
namespace engine {
|
namespace engine {
|
||||||
AllreduceRobust::AllreduceRobust(void) {
|
AllreduceRobust::AllreduceRobust(void) {
|
||||||
result_buffer_round = 1;
|
result_buffer_round = 1;
|
||||||
|
num_local_replica = 2;
|
||||||
seq_counter = 0;
|
seq_counter = 0;
|
||||||
}
|
}
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
@ -108,22 +109,38 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
|||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief load latest check point
|
* \brief load latest check point
|
||||||
* \param p_model pointer to the model
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \param local_model pointer to local model, that is specific to current node/rank
|
||||||
|
* this can be NULL when no local model is needed
|
||||||
|
*
|
||||||
* \return the version number of check point loaded
|
* \return the version number of check point loaded
|
||||||
* if returned version == 0, this means no model has been CheckPointed
|
* if returned version == 0, this means no model has been CheckPointed
|
||||||
* the p_model is not touched, user should do necessary initialization by themselves
|
* the p_model is not touched, user should do necessary initialization by themselves
|
||||||
|
*
|
||||||
|
* Common usage example:
|
||||||
|
* int iter = rabit::LoadCheckPoint(&model);
|
||||||
|
* if (iter == 0) model.InitParameters();
|
||||||
|
* for (i = iter; i < max_iter; ++i) {
|
||||||
|
* do many things, include allreduce
|
||||||
|
* rabit::CheckPoint(model);
|
||||||
|
* }
|
||||||
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
||||||
|
utils::ISerializable *local_model) {
|
||||||
|
utils::Check(local_model == NULL, "CheckPoint local_model is not yet supported");
|
||||||
// check if we succesfll
|
// check if we succesfll
|
||||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
|
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
|
||||||
// reset result buffer
|
// reset result buffer
|
||||||
resbuf.Clear(); seq_counter = 0;
|
resbuf.Clear(); seq_counter = 0;
|
||||||
// load from buffer
|
// load from buffer
|
||||||
utils::MemoryBufferStream fs(&checked_model);
|
utils::MemoryBufferStream fs(&mglobal_model);
|
||||||
fs.Read(&version_number, sizeof(version_number));
|
fs.Read(&version_number, sizeof(version_number));
|
||||||
if (version_number == 0) return version_number;
|
if (version_number == 0) return version_number;
|
||||||
p_model->Load(fs);
|
global_model->Load(fs);
|
||||||
// run another phase of check ack, if recovered from data
|
// run another phase of check ack, if recovered from data
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
||||||
"check ack must return true");
|
"check ack must return true");
|
||||||
@ -139,20 +156,31 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
|||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
*
|
*
|
||||||
* \param p_model pointer to the model
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \param local_model pointer to local model, that is specific to current node/rank
|
||||||
|
* this can be NULL when no local state is needed
|
||||||
|
*
|
||||||
|
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||||
|
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||||
|
* So only CheckPoint with global_model if possible
|
||||||
|
*
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
void AllreduceRobust::CheckPoint(const utils::ISerializable &model) {
|
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
|
||||||
// increase version number
|
const utils::ISerializable *local_model) {
|
||||||
version_number += 1;
|
utils::Assert(local_model == NULL, "CheckPoint local model is not supported yet");
|
||||||
// save model
|
|
||||||
checked_model.resize(0);
|
|
||||||
utils::MemoryBufferStream fs(&checked_model);
|
|
||||||
fs.Write(&version_number, sizeof(version_number));
|
|
||||||
model.Save(fs);
|
|
||||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
|
||||||
"check point must return true");
|
"check point must return true");
|
||||||
|
// increase version number
|
||||||
|
version_number += 1;
|
||||||
|
// save model
|
||||||
|
mglobal_model.resize(0);
|
||||||
|
utils::MemoryBufferStream fs(&mglobal_model);
|
||||||
|
fs.Write(&version_number, sizeof(version_number));
|
||||||
|
global_model->Save(fs);
|
||||||
// reset result buffer
|
// reset result buffer
|
||||||
resbuf.Clear(); seq_counter = 0;
|
resbuf.Clear(); seq_counter = 0;
|
||||||
// execute check ack step, load happens here
|
// execute check ack step, load happens here
|
||||||
@ -488,6 +516,10 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
|||||||
}
|
}
|
||||||
if (finished) break;
|
if (finished) break;
|
||||||
selecter.Select();
|
selecter.Select();
|
||||||
|
// exception handling
|
||||||
|
for (int i = 0; i < nlink; ++i) {
|
||||||
|
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
|
||||||
|
}
|
||||||
if (role == kRequestData) {
|
if (role == kRequestData) {
|
||||||
const int pid = recv_link;
|
const int pid = recv_link;
|
||||||
if (selecter.CheckRead(links[pid].sock)) {
|
if (selecter.CheckRead(links[pid].sock)) {
|
||||||
@ -548,16 +580,16 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
|||||||
*/
|
*/
|
||||||
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
||||||
RecoverType role = requester ? kRequestData : kHaveData;
|
RecoverType role = requester ? kRequestData : kHaveData;
|
||||||
size_t size = this->checked_model.length();
|
size_t size = this->mglobal_model.length();
|
||||||
int recv_link;
|
int recv_link;
|
||||||
std::vector<bool> req_in;
|
std::vector<bool> req_in;
|
||||||
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
|
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
|
||||||
if (succ != kSuccess) return succ;
|
if (succ != kSuccess) return succ;
|
||||||
if (role == kRequestData) {
|
if (role == kRequestData) {
|
||||||
checked_model.resize(size);
|
mglobal_model.resize(size);
|
||||||
}
|
}
|
||||||
if (size == 0) return kSuccess;
|
if (size == 0) return kSuccess;
|
||||||
return TryRecoverData(role, &checked_model[0], size, recv_link, req_in);
|
return TryRecoverData(role, &mglobal_model[0], size, recv_link, req_in);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief try to get the result of operation specified by seqno
|
* \brief try to get the result of operation specified by seqno
|
||||||
@ -674,6 +706,81 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
|||||||
utils::Assert(false, "RecoverExec: should not reach here");
|
utils::Assert(false, "RecoverExec: should not reach here");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
||||||
|
* this allows data to stream over a ring structure
|
||||||
|
* sendrecvbuf[0:read_ptr] are already provided by current node
|
||||||
|
* current node will recv sendrecvbuf[read_ptr:read_end] from prev link
|
||||||
|
* current node will send sendrecvbuf[write_ptr:write_end] to next link
|
||||||
|
* write_ptr will wait till the data is readed before sending the data
|
||||||
|
* this function requires read_end >= write_end
|
||||||
|
*
|
||||||
|
* \param sendrecvbuf_ the place to hold the incoming and outgoing data
|
||||||
|
* \param read_ptr the initial read pointer
|
||||||
|
* \param read_end the ending position to read
|
||||||
|
* \param write_ptr the initial write pointer
|
||||||
|
* \param write_end the ending position to write
|
||||||
|
* \param prev pointer to link to previous position in ring
|
||||||
|
* \param prev pointer to link of next position in ring
|
||||||
|
*/
|
||||||
|
AllreduceRobust::ReturnType
|
||||||
|
AllreduceRobust::RingPassing(void *sendrecvbuf_,
|
||||||
|
size_t read_ptr,
|
||||||
|
size_t read_end,
|
||||||
|
size_t write_ptr,
|
||||||
|
size_t write_end,
|
||||||
|
LinkRecord *prev_link,
|
||||||
|
LinkRecord *next_link) {
|
||||||
|
if (links.size() == 0 || read_end == 0) return kSuccess;
|
||||||
|
utils::Assert(read_end <= write_end, "boundary check");
|
||||||
|
utils::Assert(read_ptr <= read_end, "boundary check");
|
||||||
|
utils::Assert(write_ptr <= write_end, "boundary check");
|
||||||
|
// take reference
|
||||||
|
LinkRecord &prev = *prev_link, &next = *next_link;
|
||||||
|
// send recv buffer
|
||||||
|
char *buf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||||
|
while (true) {
|
||||||
|
bool finished = true;
|
||||||
|
utils::SelectHelper selecter;
|
||||||
|
if (read_ptr != read_end) {
|
||||||
|
selecter.WatchRead(prev.sock);
|
||||||
|
finished = false;
|
||||||
|
}
|
||||||
|
if (write_ptr < read_ptr && write_ptr != write_end) {
|
||||||
|
selecter.WatchWrite(next.sock);
|
||||||
|
finished = false;
|
||||||
|
}
|
||||||
|
selecter.WatchException(prev.sock);
|
||||||
|
selecter.WatchException(next.sock);
|
||||||
|
if (finished) break;
|
||||||
|
selecter.Select();
|
||||||
|
if (selecter.CheckExcept(prev.sock)) return kGetExcept;
|
||||||
|
if (selecter.CheckExcept(next.sock)) return kGetExcept;
|
||||||
|
if (read_ptr != read_end && selecter.CheckRead(prev.sock)) {
|
||||||
|
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
|
||||||
|
if (len == 0) {
|
||||||
|
prev.sock.Close(); return kSockError;
|
||||||
|
}
|
||||||
|
if (len != -1) {
|
||||||
|
read_ptr += static_cast<size_t>(len);
|
||||||
|
} else {
|
||||||
|
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (write_ptr != write_end && write_ptr < read_ptr &&
|
||||||
|
selecter.CheckWrite(next.sock)) {
|
||||||
|
size_t nsend = std::min(write_end - write_ptr, read_ptr - write_ptr);
|
||||||
|
ssize_t len = next.sock.Send(buf + write_ptr, nsend);
|
||||||
|
if (len != -1) {
|
||||||
|
write_ptr += static_cast<size_t>(len);
|
||||||
|
} else {
|
||||||
|
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return kSuccess;
|
||||||
|
}
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|
||||||
|
|||||||
@ -49,21 +49,46 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root);
|
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root);
|
||||||
/*!
|
/*!
|
||||||
* \brief load latest check point
|
* \brief load latest check point
|
||||||
* \param p_model pointer to the model
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \param local_model pointer to local model, that is specific to current node/rank
|
||||||
|
* this can be NULL when no local model is needed
|
||||||
|
*
|
||||||
* \return the version number of check point loaded
|
* \return the version number of check point loaded
|
||||||
* if returned version == 0, this means no model has been CheckPointed
|
* if returned version == 0, this means no model has been CheckPointed
|
||||||
* the p_model is not touched, user should do necessary initialization by themselves
|
* the p_model is not touched, user should do necessary initialization by themselves
|
||||||
|
*
|
||||||
|
* Common usage example:
|
||||||
|
* int iter = rabit::LoadCheckPoint(&model);
|
||||||
|
* if (iter == 0) model.InitParameters();
|
||||||
|
* for (i = iter; i < max_iter; ++i) {
|
||||||
|
* do many things, include allreduce
|
||||||
|
* rabit::CheckPoint(model);
|
||||||
|
* }
|
||||||
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual int LoadCheckPoint(utils::ISerializable *p_model);
|
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||||
|
utils::ISerializable *local_model = NULL);
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
*
|
*
|
||||||
* \param p_model pointer to the model
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \param local_model pointer to local model, that is specific to current node/rank
|
||||||
|
* this can be NULL when no local state is needed
|
||||||
|
*
|
||||||
|
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||||
|
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||||
|
* So only CheckPoint with global_model if possible
|
||||||
|
*
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const utils::ISerializable &model);
|
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||||
|
const utils::ISerializable *local_model = NULL);
|
||||||
/*!
|
/*!
|
||||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||||
* call this function when IEngine throw an exception out,
|
* call this function when IEngine throw an exception out,
|
||||||
@ -259,7 +284,7 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
* result by recovering procedure, the action is complete, no further action is needed
|
* result by recovering procedure, the action is complete, no further action is needed
|
||||||
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
||||||
*/
|
*/
|
||||||
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kMaxSeq);
|
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kMaxSeq);
|
||||||
/*!
|
/*!
|
||||||
* \brief try to load check point
|
* \brief try to load check point
|
||||||
*
|
*
|
||||||
@ -325,6 +350,30 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
size_t size,
|
size_t size,
|
||||||
int recv_link,
|
int recv_link,
|
||||||
const std::vector<bool> &req_in);
|
const std::vector<bool> &req_in);
|
||||||
|
/*!
|
||||||
|
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
||||||
|
* this allows data to stream over a ring structure
|
||||||
|
* sendrecvbuf[0:read_ptr] are already provided by current node
|
||||||
|
* current node will recv sendrecvbuf[read_ptr:read_end] from prev link
|
||||||
|
* current node will send sendrecvbuf[write_ptr:write_end] to next link
|
||||||
|
* write_ptr will wait till the data is readed before sending the data
|
||||||
|
* this function requires read_end >= write_end
|
||||||
|
*
|
||||||
|
* \param sendrecvbuf_ the place to hold the incoming and outgoing data
|
||||||
|
* \param read_ptr the initial read pointer
|
||||||
|
* \param read_end the ending position to read
|
||||||
|
* \param write_ptr the initial write pointer
|
||||||
|
* \param write_end the ending position to write
|
||||||
|
* \param prev pointer to link to previous position in ring
|
||||||
|
* \param prev pointer to link of next position in ring
|
||||||
|
*/
|
||||||
|
ReturnType RingPassing(void *senrecvbuf_,
|
||||||
|
size_t read_ptr,
|
||||||
|
size_t read_end,
|
||||||
|
size_t write_ptr,
|
||||||
|
size_t write_end,
|
||||||
|
LinkRecord *prev_link,
|
||||||
|
LinkRecord *next_link);
|
||||||
/*!
|
/*!
|
||||||
* \brief run message passing algorithm on the allreduce tree
|
* \brief run message passing algorithm on the allreduce tree
|
||||||
* the result is edge message stored in p_edge_in and p_edge_out
|
* the result is edge message stored in p_edge_in and p_edge_out
|
||||||
@ -358,10 +407,21 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
int seq_counter;
|
int seq_counter;
|
||||||
// the round of result buffer, used to mode the result
|
// the round of result buffer, used to mode the result
|
||||||
int result_buffer_round;
|
int result_buffer_round;
|
||||||
// result buffer
|
// result buffer of all reduce
|
||||||
ResultBuffer resbuf;
|
ResultBuffer resbuf;
|
||||||
// last check point model
|
// last check point global model
|
||||||
std::string checked_model;
|
std::string mglobal_model;
|
||||||
|
// number of replica for local state/model
|
||||||
|
int num_local_replica;
|
||||||
|
// pointer to memory position in the local model
|
||||||
|
// local model is stored in CSR format(like a sparse matrices)
|
||||||
|
// local_model[rptr[0]:rptr[1]] stores the model of current node
|
||||||
|
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
|
||||||
|
std::vector<size_t> local_rptr;
|
||||||
|
// storage for local model replicas
|
||||||
|
std::string mlocal_model;
|
||||||
|
// temporal storage
|
||||||
|
std::string tmp_local_model;
|
||||||
};
|
};
|
||||||
} // namespace engine
|
} // namespace engine
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|||||||
24
src/engine.h
24
src/engine.h
@ -60,7 +60,12 @@ class IEngine {
|
|||||||
virtual void InitAfterException(void) = 0;
|
virtual void InitAfterException(void) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief load latest check point
|
* \brief load latest check point
|
||||||
* \param p_model pointer to the model
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \param local_model pointer to local model, that is specific to current node/rank
|
||||||
|
* this can be NULL when no local model is needed
|
||||||
|
*
|
||||||
* \return the version number of check point loaded
|
* \return the version number of check point loaded
|
||||||
* if returned version == 0, this means no model has been CheckPointed
|
* if returned version == 0, this means no model has been CheckPointed
|
||||||
* the p_model is not touched, user should do necessary initialization by themselves
|
* the p_model is not touched, user should do necessary initialization by themselves
|
||||||
@ -75,15 +80,26 @@ class IEngine {
|
|||||||
*
|
*
|
||||||
* \sa CheckPoint, VersionNumber
|
* \sa CheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual int LoadCheckPoint(utils::ISerializable *p_model) = 0;
|
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||||
|
utils::ISerializable *local_model = NULL) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||||
* every time we call check point, there is a version number which will increase by one
|
* every time we call check point, there is a version number which will increase by one
|
||||||
*
|
*
|
||||||
* \param p_model pointer to the model
|
* \param global_model pointer to the globally shared model/state
|
||||||
|
* when calling this function, the caller need to gauranttees that global_model
|
||||||
|
* is the same in all nodes
|
||||||
|
* \param local_model pointer to local model, that is specific to current node/rank
|
||||||
|
* this can be NULL when no local state is needed
|
||||||
|
*
|
||||||
|
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||||
|
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||||
|
* So only CheckPoint with global_model if possible
|
||||||
|
*
|
||||||
* \sa LoadCheckPoint, VersionNumber
|
* \sa LoadCheckPoint, VersionNumber
|
||||||
*/
|
*/
|
||||||
virtual void CheckPoint(const utils::ISerializable &model) = 0;
|
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||||
|
const utils::ISerializable *local_model = NULL) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \return version number of current stored model,
|
* \return version number of current stored model,
|
||||||
* which means how many calls to CheckPoint we made so far
|
* which means how many calls to CheckPoint we made so far
|
||||||
|
|||||||
@ -32,10 +32,12 @@ class MPIEngine : public IEngine {
|
|||||||
virtual void InitAfterException(void) {
|
virtual void InitAfterException(void) {
|
||||||
utils::Error("MPI is not fault tolerant");
|
utils::Error("MPI is not fault tolerant");
|
||||||
}
|
}
|
||||||
virtual int LoadCheckPoint(utils::ISerializable *p_model) {
|
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||||
|
utils::ISerializable *local_model = NULL) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
virtual void CheckPoint(const utils::ISerializable &model) {
|
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||||
|
const utils::ISerializable *local_model = NULL) {
|
||||||
version_number += 1;
|
version_number += 1;
|
||||||
}
|
}
|
||||||
virtual int VersionNumber(void) const {
|
virtual int VersionNumber(void) const {
|
||||||
|
|||||||
@ -129,7 +129,7 @@ inline int LoadCheckPoint(utils::ISerializable *p_model) {
|
|||||||
}
|
}
|
||||||
// checkpoint the model, meaning we finished a stage of execution
|
// checkpoint the model, meaning we finished a stage of execution
|
||||||
inline void CheckPoint(const utils::ISerializable &model) {
|
inline void CheckPoint(const utils::ISerializable &model) {
|
||||||
engine::GetEngine()->CheckPoint(model);
|
engine::GetEngine()->CheckPoint(&model);
|
||||||
}
|
}
|
||||||
// return the version number of currently stored model
|
// return the version number of currently stored model
|
||||||
inline int VersionNumber(void) {
|
inline int VersionNumber(void) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user