add local model in checkpoint interface, a new goal
This commit is contained in:
parent
79e7862583
commit
cc410b8c90
@ -348,7 +348,7 @@ AllreduceBase::TryAllreduce(void *sendrecvbuf_,
|
||||
}
|
||||
if (len != -1) {
|
||||
size_down_in += static_cast<size_t>(len);
|
||||
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error");
|
||||
utils::Assert(size_down_in <= size_up_out, "Allreduce: boundary error");
|
||||
} else {
|
||||
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||
}
|
||||
|
||||
@ -84,23 +84,48 @@ class AllreduceBase : public IEngine {
|
||||
}
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param p_model pointer to the model
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to gauranttees that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local model is needed
|
||||
*
|
||||
* \return the version number of check point loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* the p_model is not touched, user should do necessary initialization by themselves
|
||||
*
|
||||
* Common usage example:
|
||||
* int iter = rabit::LoadCheckPoint(&model);
|
||||
* if (iter == 0) model.InitParameters();
|
||||
* for (i = iter; i < max_iter; ++i) {
|
||||
* do many things, include allreduce
|
||||
* rabit::CheckPoint(model);
|
||||
* }
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(utils::ISerializable *p_model) {
|
||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
*
|
||||
* \param p_model pointer to the model
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller need to gauranttees that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
*
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||
* So only CheckPoint with global_model if possible
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const utils::ISerializable &model) {
|
||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
/*!
|
||||
@ -267,6 +292,8 @@ class AllreduceBase : public IEngine {
|
||||
int parent_rank;
|
||||
// sockets of all links
|
||||
std::vector<LinkRecord> links;
|
||||
// pointer to someplace in the ring
|
||||
LinkRecord *ring_prev, *ring_next;
|
||||
//----- meta information-----
|
||||
// unique identifier of the possible job this process is doing
|
||||
// used to assign ranks, optional, default to NULL
|
||||
|
||||
@ -17,6 +17,7 @@ namespace rabit {
|
||||
namespace engine {
|
||||
/*! \brief constructor, assigns the initial values of the robust engine's counters */
AllreduceRobust::AllreduceRobust(void) {
  // no operation has been sequenced yet
  this->seq_counter = 0;
  // number of replicas for the local state/model
  this->num_local_replica = 2;
  // round of the result buffer
  this->result_buffer_round = 1;
}
|
||||
/*! \brief shutdown the engine */
|
||||
@ -108,22 +109,38 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
|
||||
}
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param p_model pointer to the model
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local model is needed
|
||||
*
|
||||
* \return the version number of check point loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* the p_model is not touched, user should do necessary initialization by themselves
|
||||
*
|
||||
* Common usage example:
|
||||
* int iter = rabit::LoadCheckPoint(&model);
|
||||
* if (iter == 0) model.InitParameters();
|
||||
* for (i = iter; i < max_iter; ++i) {
|
||||
* do many things, include allreduce
|
||||
* rabit::CheckPoint(model);
|
||||
* }
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
||||
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model) {
|
||||
utils::Check(local_model == NULL, "CheckPoint local_model is not yet supported");
|
||||
// check whether the recovery step succeeded in loading a checkpoint
|
||||
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
|
||||
// reset result buffer
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
// load from buffer
|
||||
utils::MemoryBufferStream fs(&checked_model);
|
||||
utils::MemoryBufferStream fs(&mglobal_model);
|
||||
fs.Read(&version_number, sizeof(version_number));
|
||||
if (version_number == 0) return version_number;
|
||||
p_model->Load(fs);
|
||||
global_model->Load(fs);
|
||||
// run another phase of check ack, if recovered from data
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
||||
"check ack must return true");
|
||||
@ -139,20 +156,31 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
*
|
||||
* \param p_model pointer to the model
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
*
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||
* So only CheckPoint with global_model if possible
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
void AllreduceRobust::CheckPoint(const utils::ISerializable &model) {
|
||||
// increase version number
|
||||
version_number += 1;
|
||||
// save model
|
||||
checked_model.resize(0);
|
||||
utils::MemoryBufferStream fs(&checked_model);
|
||||
fs.Write(&version_number, sizeof(version_number));
|
||||
model.Save(fs);
|
||||
void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model) {
|
||||
utils::Assert(local_model == NULL, "CheckPoint local model is not supported yet");
|
||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
|
||||
"check point must return true");
|
||||
// increase version number
|
||||
version_number += 1;
|
||||
// save model
|
||||
mglobal_model.resize(0);
|
||||
utils::MemoryBufferStream fs(&mglobal_model);
|
||||
fs.Write(&version_number, sizeof(version_number));
|
||||
global_model->Save(fs);
|
||||
// reset result buffer
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
// execute check ack step, load happens here
|
||||
@ -488,6 +516,10 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
||||
}
|
||||
if (finished) break;
|
||||
selecter.Select();
|
||||
// exception handling
|
||||
for (int i = 0; i < nlink; ++i) {
|
||||
if (selecter.CheckExcept(links[i].sock)) return kGetExcept;
|
||||
}
|
||||
if (role == kRequestData) {
|
||||
const int pid = recv_link;
|
||||
if (selecter.CheckRead(links[pid].sock)) {
|
||||
@ -548,16 +580,16 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
||||
*/
|
||||
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
  // a requester pulls the checkpoint; every other node acts as a holder of the data
  RecoverType role = requester ? kRequestData : kHaveData;
  // start from the size of the locally stored global model;
  // NOTE(review): TryDecideRouting takes size by pointer and presumably updates it to the
  // agreed checkpoint size, since the requester resizes its buffer to it below -- confirm
  size_t size = this->mglobal_model.length();
  int recv_link;
  std::vector<bool> req_in;
  ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
  if (succ != kSuccess) return succ;
  if (role == kRequestData) {
    // make room for the incoming checkpoint bytes
    mglobal_model.resize(size);
  }
  // nothing to transfer; this also keeps &mglobal_model[0] below from touching an empty buffer
  if (size == 0) return kSuccess;
  return TryRecoverData(role, &mglobal_model[0], size, recv_link, req_in);
}
|
||||
/*!
|
||||
* \brief try to get the result of operation specified by seqno
|
||||
@ -674,6 +706,81 @@ bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
||||
utils::Assert(false, "RecoverExec: should not reach here");
|
||||
return true;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
||||
* this allows data to stream over a ring structure
|
||||
* sendrecvbuf[0:read_ptr] are already provided by current node
|
||||
* current node will recv sendrecvbuf[read_ptr:read_end] from prev link
|
||||
* current node will send sendrecvbuf[write_ptr:write_end] to next link
|
||||
* write_ptr will wait till the data is readed before sending the data
|
||||
* this function requires read_end >= write_end
|
||||
*
|
||||
* \param sendrecvbuf_ the place to hold the incoming and outgoing data
|
||||
* \param read_ptr the initial read pointer
|
||||
* \param read_end the ending position to read
|
||||
* \param write_ptr the initial write pointer
|
||||
* \param write_end the ending position to write
|
||||
* \param prev pointer to link to previous position in ring
|
||||
* \param prev pointer to link of next position in ring
|
||||
*/
|
||||
AllreduceRobust::ReturnType
|
||||
AllreduceRobust::RingPassing(void *sendrecvbuf_,
|
||||
size_t read_ptr,
|
||||
size_t read_end,
|
||||
size_t write_ptr,
|
||||
size_t write_end,
|
||||
LinkRecord *prev_link,
|
||||
LinkRecord *next_link) {
|
||||
if (links.size() == 0 || read_end == 0) return kSuccess;
|
||||
utils::Assert(read_end <= write_end, "boundary check");
|
||||
utils::Assert(read_ptr <= read_end, "boundary check");
|
||||
utils::Assert(write_ptr <= write_end, "boundary check");
|
||||
// take reference
|
||||
LinkRecord &prev = *prev_link, &next = *next_link;
|
||||
// send recv buffer
|
||||
char *buf = reinterpret_cast<char*>(sendrecvbuf_);
|
||||
while (true) {
|
||||
bool finished = true;
|
||||
utils::SelectHelper selecter;
|
||||
if (read_ptr != read_end) {
|
||||
selecter.WatchRead(prev.sock);
|
||||
finished = false;
|
||||
}
|
||||
if (write_ptr < read_ptr && write_ptr != write_end) {
|
||||
selecter.WatchWrite(next.sock);
|
||||
finished = false;
|
||||
}
|
||||
selecter.WatchException(prev.sock);
|
||||
selecter.WatchException(next.sock);
|
||||
if (finished) break;
|
||||
selecter.Select();
|
||||
if (selecter.CheckExcept(prev.sock)) return kGetExcept;
|
||||
if (selecter.CheckExcept(next.sock)) return kGetExcept;
|
||||
if (read_ptr != read_end && selecter.CheckRead(prev.sock)) {
|
||||
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
|
||||
if (len == 0) {
|
||||
prev.sock.Close(); return kSockError;
|
||||
}
|
||||
if (len != -1) {
|
||||
read_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||
}
|
||||
}
|
||||
if (write_ptr != write_end && write_ptr < read_ptr &&
|
||||
selecter.CheckWrite(next.sock)) {
|
||||
size_t nsend = std::min(write_end - write_ptr, read_ptr - write_ptr);
|
||||
ssize_t len = next.sock.Send(buf + write_ptr, nsend);
|
||||
if (len != -1) {
|
||||
write_ptr += static_cast<size_t>(len);
|
||||
} else {
|
||||
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
||||
}
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
|
||||
@ -49,21 +49,46 @@ class AllreduceRobust : public AllreduceBase {
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root);
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param p_model pointer to the model
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local model is needed
|
||||
*
|
||||
* \return the version number of check point loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* the p_model is not touched, user should do necessary initialization by themselves
|
||||
*
|
||||
* Common usage example:
|
||||
* int iter = rabit::LoadCheckPoint(&model);
|
||||
* if (iter == 0) model.InitParameters();
|
||||
* for (i = iter; i < max_iter; ++i) {
|
||||
* do many things, include allreduce
|
||||
* rabit::CheckPoint(model);
|
||||
* }
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(utils::ISerializable *p_model);
|
||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
*
|
||||
* \param p_model pointer to the model
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
*
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||
* So only CheckPoint with global_model if possible
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const utils::ISerializable &model);
|
||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model = NULL);
|
||||
/*!
|
||||
* \brief explicitly re-init everything before calling LoadCheckPoint
|
||||
* call this function when IEngine throw an exception out,
|
||||
@ -259,7 +284,7 @@ class AllreduceRobust : public AllreduceBase {
|
||||
* result by recovering procedure, the action is complete, no further action is needed
|
||||
* - false means this is the latest action that has not yet been executed, need to execute the action
|
||||
*/
|
||||
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kMaxSeq);
|
||||
bool RecoverExec(void *buf, size_t size, int flag, int seqno = ActionSummary::kMaxSeq);
|
||||
/*!
|
||||
* \brief try to load check point
|
||||
*
|
||||
@ -325,6 +350,30 @@ class AllreduceRobust : public AllreduceBase {
|
||||
size_t size,
|
||||
int recv_link,
|
||||
const std::vector<bool> &req_in);
|
||||
/*!
|
||||
* \brief perform a ring passing to receive data from prev link, and sent data to next link
|
||||
* this allows data to stream over a ring structure
|
||||
* sendrecvbuf[0:read_ptr] are already provided by current node
|
||||
* current node will recv sendrecvbuf[read_ptr:read_end] from prev link
|
||||
* current node will send sendrecvbuf[write_ptr:write_end] to next link
|
||||
* write_ptr will wait till the data is read before sending the data
|
||||
* this function requires read_end >= write_end
|
||||
*
|
||||
* \param sendrecvbuf_ the place to hold the incoming and outgoing data
|
||||
* \param read_ptr the initial read pointer
|
||||
* \param read_end the ending position to read
|
||||
* \param write_ptr the initial write pointer
|
||||
* \param write_end the ending position to write
|
||||
* \param prev_link pointer to link to previous position in ring
|
||||
* \param next_link pointer to link of next position in ring
|
||||
*/
|
||||
ReturnType RingPassing(void *senrecvbuf_,
|
||||
size_t read_ptr,
|
||||
size_t read_end,
|
||||
size_t write_ptr,
|
||||
size_t write_end,
|
||||
LinkRecord *prev_link,
|
||||
LinkRecord *next_link);
|
||||
/*!
|
||||
* \brief run message passing algorithm on the allreduce tree
|
||||
* the result is edge message stored in p_edge_in and p_edge_out
|
||||
@ -358,10 +407,21 @@ class AllreduceRobust : public AllreduceBase {
|
||||
int seq_counter;
|
||||
// the round of result buffer, used to mode the result
|
||||
int result_buffer_round;
|
||||
// result buffer
|
||||
// result buffer of all reduce
|
||||
ResultBuffer resbuf;
|
||||
// last check point model
|
||||
std::string checked_model;
|
||||
// last check point global model
|
||||
std::string mglobal_model;
|
||||
// number of replica for local state/model
|
||||
int num_local_replica;
|
||||
// pointer to memory position in the local model
|
||||
// local model is stored in CSR format(like a sparse matrices)
|
||||
// local_model[rptr[0]:rptr[1]] stores the model of current node
|
||||
// local_model[rptr[k]:rptr[k+1]] stores the model of node in previous k hops in the ring
|
||||
std::vector<size_t> local_rptr;
|
||||
// storage for local model replicas
|
||||
std::string mlocal_model;
|
||||
// temporal storage
|
||||
std::string tmp_local_model;
|
||||
};
|
||||
} // namespace engine
|
||||
} // namespace rabit
|
||||
|
||||
24
src/engine.h
24
src/engine.h
@ -60,7 +60,12 @@ class IEngine {
|
||||
virtual void InitAfterException(void) = 0;
|
||||
/*!
|
||||
* \brief load latest check point
|
||||
* \param p_model pointer to the model
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local model is needed
|
||||
*
|
||||
* \return the version number of check point loaded
|
||||
* if returned version == 0, this means no model has been CheckPointed
|
||||
* the p_model is not touched, user should do necessary initialization by themselves
|
||||
@ -75,15 +80,26 @@ class IEngine {
|
||||
*
|
||||
* \sa CheckPoint, VersionNumber
|
||||
*/
|
||||
virtual int LoadCheckPoint(utils::ISerializable *p_model) = 0;
|
||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model = NULL) = 0;
|
||||
/*!
|
||||
* \brief checkpoint the model, meaning we finished a stage of execution
|
||||
* every time we call check point, there is a version number which will increase by one
|
||||
*
|
||||
* \param p_model pointer to the model
|
||||
* \param global_model pointer to the globally shared model/state
|
||||
* when calling this function, the caller needs to guarantee that global_model
|
||||
* is the same in all nodes
|
||||
* \param local_model pointer to local model, that is specific to current node/rank
|
||||
* this can be NULL when no local state is needed
|
||||
*
|
||||
* NOTE: local_model requires explicit replication of the model for fault-tolerance, which will
|
||||
* bring replication cost in CheckPoint function. global_model do not need explicit replication.
|
||||
* So only CheckPoint with global_model if possible
|
||||
*
|
||||
* \sa LoadCheckPoint, VersionNumber
|
||||
*/
|
||||
virtual void CheckPoint(const utils::ISerializable &model) = 0;
|
||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model = NULL) = 0;
|
||||
/*!
|
||||
* \return version number of current stored model,
|
||||
* which means how many calls to CheckPoint we made so far
|
||||
|
||||
@ -32,10 +32,12 @@ class MPIEngine : public IEngine {
|
||||
/*! \brief re-initialization entry point; this engine only reports an error through utils::Error */
virtual void InitAfterException(void) {
  utils::Error("MPI is not fault tolerant");
}
|
||||
virtual int LoadCheckPoint(utils::ISerializable *p_model) {
|
||||
virtual int LoadCheckPoint(utils::ISerializable *global_model,
|
||||
utils::ISerializable *local_model = NULL) {
|
||||
return 0;
|
||||
}
|
||||
virtual void CheckPoint(const utils::ISerializable &model) {
|
||||
virtual void CheckPoint(const utils::ISerializable *global_model,
|
||||
const utils::ISerializable *local_model = NULL) {
|
||||
version_number += 1;
|
||||
}
|
||||
virtual int VersionNumber(void) const {
|
||||
|
||||
@ -129,7 +129,7 @@ inline int LoadCheckPoint(utils::ISerializable *p_model) {
|
||||
}
|
||||
// checkpoint the model, meaning we finished a stage of execution
|
||||
inline void CheckPoint(const utils::ISerializable &model) {
|
||||
engine::GetEngine()->CheckPoint(model);
|
||||
engine::GetEngine()->CheckPoint(&model);
|
||||
}
|
||||
// return the version number of currently stored model
|
||||
inline int VersionNumber(void) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user