334 lines
12 KiB
C++
334 lines
12 KiB
C++
/*!
|
|
* \file engine_robust.cc
|
|
* \brief Robust implementation of AllReduce
|
|
* \author Tianqi, Nacho, Tianyi
|
|
*/
|
|
#define _CRT_SECURE_NO_WARNINGS
|
|
#define _CRT_SECURE_NO_DEPRECATE
|
|
#define NOMINMAX
|
|
#include "./utils.h"
|
|
#include "./engine_robust.h"
|
|
|
|
namespace engine {
|
|
/*!
|
|
* \brief perform in-place allreduce, on sendrecvbuf
|
|
* this function is NOT thread-safe
|
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
|
* \param type_nbytes the unit number of bytes the type have
|
|
* \param count number of elements to be reduced
|
|
* \param reducer reduce function
|
|
*/
|
|
void AllReduceRobust::AllReduce(void *sendrecvbuf_,
|
|
size_t type_nbytes,
|
|
size_t count,
|
|
ReduceFunction reducer) {
|
|
while (true) {
|
|
ReturnType ret = TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer);
|
|
if (ret == kSuccess) return;
|
|
if (ret == kSockError) {
|
|
utils::Error("error occur during all reduce\n");
|
|
}
|
|
utils::LogPrintf("[%d] receive except signal, start reset link\n", rank);
|
|
TryResetLinks();
|
|
}
|
|
// TODO
|
|
}
|
|
/*!
|
|
* \brief broadcast data from root to all nodes
|
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
|
* \param size the size of the data to be broadcasted
|
|
* \param root the root worker id to broadcast the data
|
|
*/
|
|
void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
|
utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
|
|
"AllReduce failed");
|
|
// TODO
|
|
}
|
|
/*!
|
|
* \brief load latest check point
|
|
* \param p_model pointer to the model
|
|
* \return true if there was stored checkpoint and load was successful
|
|
* false if there was no stored checkpoint, means we are start over gain
|
|
*/
|
|
bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
|
// TODO
|
|
return false;
|
|
}
|
|
/*!
|
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
|
* \param p_model pointer to the model
|
|
*/
|
|
void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
|
|
// TODO
|
|
}
|
|
/*!
|
|
* \brief reset the all the existing links by sending Out-of-Band message marker
|
|
* after this function finishes, all the messages received and sent before in all live links are discarded,
|
|
* This allows us to get a fresh start after error has happened
|
|
*
|
|
* \return this function can return kSuccess or kSockError
|
|
* when kSockError is returned, it simply means there are bad sockets in the links,
|
|
* and some link recovery proceduer is needed
|
|
*/
|
|
AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) {
|
|
// number of links
|
|
const int nlink = static_cast<int>(links.size());
|
|
for (int i = 0; i < nlink; ++i) {
|
|
links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
|
|
links[i].ResetSize();
|
|
}
|
|
// read and discard data from all channels until pass mark
|
|
while (true) {
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (links[i].sock.BadSocket()) continue;
|
|
if (links[i].size_write == 0) {
|
|
char sig = kOOBReset;
|
|
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
|
|
// error will be filtered in next loop
|
|
if (len == sizeof(sig)) links[i].size_write = 1;
|
|
}
|
|
if (links[i].size_write == 1) {
|
|
char sig = kResetMark;
|
|
ssize_t len = links[i].sock.Send(&sig, sizeof(sig));
|
|
if (len == sizeof(sig)) links[i].size_write = 2;
|
|
}
|
|
}
|
|
utils::SelectHelper rsel;
|
|
bool finished = true;
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (links[i].size_write != 2 && !links[i].sock.BadSocket()) {
|
|
rsel.WatchWrite(links[i].sock); finished = false;
|
|
}
|
|
}
|
|
if (finished) break;
|
|
// wait to read from the channels to discard data
|
|
rsel.Select();
|
|
}
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (!links[i].sock.BadSocket()) {
|
|
utils::SelectHelper::WaitExcept(links[i].sock);
|
|
}
|
|
}
|
|
while (true) {
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (links[i].size_read == 0) {
|
|
int atmark = links[i].sock.AtMark();
|
|
if (atmark < 0) {
|
|
utils::Assert(links[i].sock.BadSocket(), "must already gone bad");
|
|
} else if (atmark > 0) {
|
|
links[i].size_read = 1;
|
|
} else {
|
|
// no at mark, read and discard data
|
|
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
|
|
if (links[i].sock.AtMark()) links[i].size_read = 1;
|
|
// zero length, remote closed the connection, close socket
|
|
if (len == 0) links[i].sock.Close();
|
|
}
|
|
}
|
|
}
|
|
utils::SelectHelper rsel;
|
|
bool finished = true;
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (links[i].size_read == 0 && !links[i].sock.BadSocket()) {
|
|
rsel.WatchRead(links[i].sock); finished = false;
|
|
}
|
|
}
|
|
if (finished) break;
|
|
rsel.Select();
|
|
}
|
|
|
|
// start synchronization, use blocking I/O to avoid select
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (!links[i].sock.BadSocket()) {
|
|
char oob_mark;
|
|
links[i].sock.SetNonBlock(false);
|
|
ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
|
|
if (len == 0) {
|
|
links[i].sock.Close(); continue;
|
|
} else if (len > 0) {
|
|
utils::Assert(oob_mark == kResetMark, "wrong oob msg");
|
|
utils::Assert(links[i].sock.AtMark() != 1, "should already read past mark");
|
|
} else {
|
|
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
|
}
|
|
// send out ack
|
|
char ack = kResetAck;
|
|
while (true) {
|
|
len = links[i].sock.Send(&ack, sizeof(ack));
|
|
if (len == sizeof(ack)) break;
|
|
if (len == -1) {
|
|
if (errno != EAGAIN && errno != EWOULDBLOCK) break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// wait all ack
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (!links[i].sock.BadSocket()) {
|
|
char ack;
|
|
ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
|
|
if (len == 0) {
|
|
links[i].sock.Close(); continue;
|
|
} else if (len > 0) {
|
|
utils::Assert(ack == kResetAck, "wrong Ack MSG");
|
|
} else {
|
|
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
|
}
|
|
// set back to nonblock mode
|
|
links[i].sock.SetNonBlock(true);
|
|
}
|
|
}
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (links[i].sock.BadSocket()) return kSockError;
|
|
}
|
|
return kSuccess;
|
|
}
|
|
/*!
|
|
* \brief try to reconnect the broken links
|
|
* \return this function can kSuccess or kSockError
|
|
*/
|
|
AllReduceRobust::ReturnType AllReduceRobust::TryReConnectLinks(void) {
|
|
utils::Error("TryReConnectLinks: not implemented");
|
|
return kSuccess;
|
|
}
|
|
/*!
|
|
* \brief if err_type indicates an error
|
|
* recover links according to the error type reported
|
|
* if there is no error, return true
|
|
* \param err_type the type of error happening in the system
|
|
* \return true if err_type is kSuccess, false otherwise
|
|
*/
|
|
bool AllReduceRobust::CheckAndRecover(ReturnType err_type) {
|
|
if (err_type == kSuccess) return true;
|
|
while(err_type != kSuccess) {
|
|
switch(err_type) {
|
|
case kGetExcept: err_type = TryResetLinks(); break;
|
|
case kSockError: {
|
|
TryResetLinks();
|
|
err_type = TryReConnectLinks();
|
|
break;
|
|
}
|
|
default: utils::Assert(false, "RecoverLinks: cannot reach here");
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
/*!
|
|
* \brief try to load check point
|
|
*
|
|
* This is a collaborative function called by all nodes
|
|
* only the nodes with requester set to true really needs to load the check point
|
|
* other nodes acts as collaborative roles to complete this request
|
|
*
|
|
* \param requester whether current node is the requester
|
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
|
* \sa ReturnType
|
|
*/
|
|
AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
|
|
utils::Error("TryLoadCheckPoint: not implemented");
|
|
return kSuccess;
|
|
}
|
|
/*!
|
|
* \brief try to get the result of operation specified by seqno
|
|
*
|
|
* This is a collaborative function called by all nodes
|
|
* only the nodes with requester set to true really needs to get the result
|
|
* other nodes acts as collaborative roles to complete this request
|
|
*
|
|
* \param buf the buffer to store the result, this parameter is only use when current node is requester
|
|
* \param size the total size of the buffer, this parameter is only use when current node is requester
|
|
* \param seqno sequence number of the operation, this is unique index of a operation in current iteration
|
|
* \param requester whether current node is the requester
|
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
|
* \sa ReturnType
|
|
*/
|
|
AllReduceRobust::ReturnType AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) {
|
|
utils::Error("TryGetResult: not implemented");
|
|
return kSuccess;
|
|
}
|
|
/*!
|
|
* \brief try to run recover execution for a request action described by flag and seqno,
|
|
* the function will keep blocking to run possible recovery operations before the specified action,
|
|
* until the requested result is received by a recovering procedure,
|
|
* or the function discovers that the requested action is not yet executed, and return false
|
|
*
|
|
* \param buf the buffer to store the result
|
|
* \param size the total size of the buffer
|
|
* \param flag flag information about the action \sa ActionSummary
|
|
* \param seqno sequence number of the action, if it is special action with flag set,
|
|
* seqno needs to be set to ActionSummary::kMaxSeq
|
|
*
|
|
* \return if this function can return true or false
|
|
* - true means buf already set to the
|
|
* result by recovering procedure, the action is complete, no further action is needed
|
|
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
|
*/
|
|
bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
|
if (flag != 0) {
|
|
utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations");
|
|
}
|
|
// request
|
|
ActionSummary req(flag, seqno);
|
|
while (true) {
|
|
// action
|
|
ActionSummary act = req;
|
|
// get the reduced action
|
|
if (!CheckAndRecover(TryAllReduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue;
|
|
if (act.check_ack()) {
|
|
if (act.check_point()) {
|
|
// if we also have check_point, do check point first
|
|
utils::Assert(!act.diff_seq(),
|
|
"check ack & check pt cannot occur together with normal ops");
|
|
// if we requested checkpoint, we are free to go
|
|
if (req.check_point()) return true;
|
|
} else if (act.load_check()) {
|
|
// if there is only check_ack and load_check, do load_check
|
|
if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
|
|
// if requested load check, then misson complete
|
|
if (req.load_check()) return true;
|
|
} else {
|
|
// there is no check point and no load check, execute check ack
|
|
if (req.check_ack()) return true;
|
|
}
|
|
// if execute to this point
|
|
// this means the action requested has not been completed
|
|
// try next round
|
|
} else {
|
|
if (act.check_point()) {
|
|
if (act.diff_seq()) {
|
|
utils::Assert(act.min_seqno() != ActionSummary::kMaxSeq, "min seq bug");
|
|
bool requester = req.min_seqno() == act.min_seqno();
|
|
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
|
|
if (requester) return true;
|
|
} else {
|
|
// no difference in seq no, means we are free to check point
|
|
if (req.check_point()) return true;
|
|
}
|
|
} else {
|
|
// no check point
|
|
if (act.load_check()) {
|
|
// load check have higher priority, do load_check
|
|
if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
|
|
// if requested load check, then misson complete
|
|
if (req.load_check()) return true;
|
|
} else {
|
|
// no special flags, no checkpoint, check ack, load_check
|
|
utils::Assert(act.min_seqno() != ActionSummary::kMaxSeq, "min seq bug");
|
|
if (act.diff_seq()) {
|
|
bool requester = req.min_seqno() == act.min_seqno();
|
|
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
|
|
if (requester) return true;
|
|
} else {
|
|
// all the request is same, this is most recent command that is yet to be executed
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
// something is still incomplete try next round
|
|
}
|
|
}
|
|
utils::Assert(false, "RecoverExec: should not reach here");
|
|
return true;
|
|
}
|
|
} // namespace engine
|