xgboost/src/engine_robust.cc
2014-11-30 11:45:50 -08:00

334 lines
12 KiB
C++

/*!
* \file engine_robust.cc
* \brief Robust implementation of AllReduce
* \author Tianqi, Nacho, Tianyi
*/
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include "./utils.h"
#include "./engine_robust.h"
namespace engine {
/*!
* \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param type_nbytes the number of bytes in one element of the type
* \param count number of elements to be reduced
* \param reducer reduce function
*/
void AllReduceRobust::AllReduce(void *sendrecvbuf_,
size_t type_nbytes,
size_t count,
ReduceFunction reducer) {
while (true) {
ReturnType ret = TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer);
if (ret == kSuccess) return;
if (ret == kSockError) {
utils::Error("error occur during all reduce\n");
}
utils::LogPrintf("[%d] receive except signal, start reset link\n", rank);
TryResetLinks();
}
// TODO
}
/*!
* \brief broadcast data from root to all nodes
* \param sendrecvbuf_ buffer for both sending and receiving data
* \param total_size the size of the data to be broadcasted
* \param root the root worker id to broadcast the data
*/
void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
  // Forward to TryBroadcast and treat anything other than kSuccess as a failed
  // assertion. The message names Broadcast (it previously said "AllReduce
  // failed", which pointed at the wrong operation when diagnosing a failure).
  utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess,
                "Broadcast failed");
  // TODO: unlike AllReduce, no link reset / retry is attempted here yet
}
/*!
* \brief load latest check point
* \param p_model pointer to the model
* \return true if there was a stored checkpoint and the load was successful
* false if there was no stored checkpoint, meaning we are starting over again
*/
bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
  // TODO
  // Placeholder: nothing is loaded into p_model yet, so always report that
  // no checkpoint was available.
  (void)p_model;
  return false;
}
/*!
* \brief checkpoint the model, meaning we finished a stage of execution
* \param model the model to be checkpointed
*/
void AllReduceRobust::CheckPoint(const utils::ISerializable &model) {
  // TODO
  // Placeholder: the model is not stored anywhere yet.
  (void)model;
}
/*!
* \brief reset all the existing links by sending an Out-of-Band message marker
* after this function finishes, all the messages received and sent before in all live links are discarded,
* This allows us to get a fresh start after an error has happened
*
* \return this function can return kSuccess or kSockError
* when kSockError is returned, it simply means there are bad sockets in the links,
* and some link recovery procedure is needed
*/
AllReduceRobust::ReturnType AllReduceRobust::TryResetLinks(void) {
  // number of links
  const int nlink = static_cast<int>(links.size());
  // Prepare every link for the reset handshake. size_write / size_read are
  // reused below as per-link progress counters for the handshake.
  // NOTE(review): this relies on ResetSize() leaving both counters at 0 -- confirm.
  for (int i = 0; i < nlink; ++i) {
    links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
    links[i].ResetSize();
  }
  // Stage 1: on every live link send the out-of-band reset signal followed by
  // the in-band reset mark. size_write tracks progress per link:
  //   0 = nothing sent yet, 1 = OOB signal sent, 2 = reset mark sent.
  while (true) {
    for (int i = 0; i < nlink; ++i) {
      if (links[i].sock.BadSocket()) continue;
      if (links[i].size_write == 0) {
        char sig = kOOBReset;
        ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
        // error will be filtered in next loop
        if (len == sizeof(sig)) links[i].size_write = 1;
      }
      if (links[i].size_write == 1) {
        char sig = kResetMark;
        ssize_t len = links[i].sock.Send(&sig, sizeof(sig));
        if (len == sizeof(sig)) links[i].size_write = 2;
      }
    }
    utils::SelectHelper rsel;
    bool finished = true;
    for (int i = 0; i < nlink; ++i) {
      if (links[i].size_write != 2 && !links[i].sock.BadSocket()) {
        rsel.WatchWrite(links[i].sock); finished = false;
      }
    }
    if (finished) break;
    // wait until at least one unfinished link is writable again, then retry
    rsel.Select();
  }
  // Stage 2: wait for the exceptional condition on every live link
  // (presumably the arrival of the peer's OOB signal -- see SelectHelper::WaitExcept).
  for (int i = 0; i < nlink; ++i) {
    if (!links[i].sock.BadSocket()) {
      utils::SelectHelper::WaitExcept(links[i].sock);
    }
  }
  // Stage 3: read and discard stale data on each link until the socket is at
  // the OOB mark. size_read: 0 = still draining, 1 = mark reached.
  while (true) {
    for (int i = 0; i < nlink; ++i) {
      if (links[i].size_read == 0) {
        int atmark = links[i].sock.AtMark();
        if (atmark < 0) {
          utils::Assert(links[i].sock.BadSocket(), "must already gone bad");
        } else if (atmark > 0) {
          links[i].size_read = 1;
        } else {
          // no at mark, read and discard data
          ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
          // NOTE(review): any non-zero AtMark() result, including a negative
          // one, marks the link as done here -- confirm this is intended.
          if (links[i].sock.AtMark()) links[i].size_read = 1;
          // zero length, remote closed the connection, close socket
          if (len == 0) links[i].sock.Close();
        }
      }
    }
    utils::SelectHelper rsel;
    bool finished = true;
    for (int i = 0; i < nlink; ++i) {
      if (links[i].size_read == 0 && !links[i].sock.BadSocket()) {
        rsel.WatchRead(links[i].sock); finished = false;
      }
    }
    if (finished) break;
    rsel.Select();
  }
  // Stage 4: start synchronization, use blocking I/O to avoid select.
  // Read the peer's in-band reset mark, then answer with an ack.
  for (int i = 0; i < nlink; ++i) {
    if (!links[i].sock.BadSocket()) {
      char oob_mark;
      links[i].sock.SetNonBlock(false);
      ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
      if (len == 0) {
        links[i].sock.Close(); continue;
      } else if (len > 0) {
        utils::Assert(oob_mark == kResetMark, "wrong oob msg");
        utils::Assert(links[i].sock.AtMark() != 1, "should already read past mark");
      } else {
        // The socket is in blocking mode, so a would-block errno is a bug.
        // Both codes must be excluded: the former "||" was always true on
        // platforms where EAGAIN and EWOULDBLOCK have different values.
        utils::Assert(errno != EAGAIN && errno != EWOULDBLOCK, "BUG");
      }
      // send out ack
      char ack = kResetAck;
      while (true) {
        len = links[i].sock.Send(&ack, sizeof(ack));
        if (len == sizeof(ack)) break;
        if (len == -1) {
          // give up on hard errors; only would-block style errors are retried
          if (errno != EAGAIN && errno != EWOULDBLOCK) break;
        }
      }
    }
  }
  // Stage 5: wait all ack
  for (int i = 0; i < nlink; ++i) {
    if (!links[i].sock.BadSocket()) {
      char ack;
      ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
      if (len == 0) {
        links[i].sock.Close(); continue;
      } else if (len > 0) {
        utils::Assert(ack == kResetAck, "wrong Ack MSG");
      } else {
        // same reasoning as in stage 4: still blocking, would-block is a bug
        utils::Assert(errno != EAGAIN && errno != EWOULDBLOCK, "BUG");
      }
      // set back to nonblock mode
      links[i].sock.SetNonBlock(true);
    }
  }
  // any link that went bad along the way makes the whole reset a socket error
  for (int i = 0; i < nlink; ++i) {
    if (links[i].sock.BadSocket()) return kSockError;
  }
  return kSuccess;
}
/*!
* \brief try to reconnect the broken links
* \return this function can return kSuccess or kSockError
*/
AllReduceRobust::ReturnType AllReduceRobust::TryReConnectLinks(void) {
  // Placeholder: reconnection is not written yet, which is reported through
  // utils::Error; kSuccess is returned afterwards.
  utils::Error("TryReConnectLinks: not implemented");
  return kSuccess;
}
/*!
* \brief if err_type indicates an error
* recover links according to the error type reported
* if there is no error, return true
* \param err_type the type of error happening in the system
* \return true if err_type is kSuccess, false otherwise
*/
bool AllReduceRobust::CheckAndRecover(ReturnType err_type) {
  // nothing went wrong: tell the caller it may proceed
  if (err_type == kSuccess) return true;
  // An error was reported: keep running recovery steps until one of them
  // reports kSuccess. The caller still gets false, because the operation
  // that produced the error did not complete.
  do {
    if (err_type == kGetExcept) {
      // exception signal: resetting the links is the whole recovery
      err_type = TryResetLinks();
    } else if (err_type == kSockError) {
      // socket error: reset first (its status is not used), then reconnect
      TryResetLinks();
      err_type = TryReConnectLinks();
    } else {
      utils::Assert(false, "RecoverLinks: cannot reach here");
    }
  } while (err_type != kSuccess);
  return false;
}
/*!
* \brief try to load check point
*
* This is a collaborative function called by all nodes
* only the nodes with requester set to true really need to load the check point
* other nodes act in collaborative roles to complete this request
*
* \param requester whether current node is the requester
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
  // Placeholder: the checkpoint loading protocol is not written yet, which is
  // reported through utils::Error; kSuccess is returned afterwards.
  (void)requester;
  utils::Error("TryLoadCheckPoint: not implemented");
  return kSuccess;
}
/*!
* \brief try to get the result of operation specified by seqno
*
* This is a collaborative function called by all nodes
* only the nodes with requester set to true really need to get the result
* other nodes act in collaborative roles to complete this request
*
* \param sendrecvbuf the buffer to store the result, this parameter is only used when current node is requester
* \param size the total size of the buffer, this parameter is only used when current node is requester
* \param seqno sequence number of the operation, this is the unique index of an operation in current iteration
* \param requester whether current node is the requester
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
* \sa ReturnType
*/
AllReduceRobust::ReturnType AllReduceRobust::TryGetResult(void *sendrecvbuf,
                                                          size_t size,
                                                          int seqno,
                                                          bool requester) {
  // Placeholder: result recovery is not written yet, which is reported
  // through utils::Error; kSuccess is returned afterwards.
  (void)sendrecvbuf;
  (void)size;
  (void)seqno;
  (void)requester;
  utils::Error("TryGetResult: not implemented");
  return kSuccess;
}
/*!
* \brief try to run recover execution for a request action described by flag and seqno,
* the function will keep blocking to run possible recovery operations before the specified action,
* until the requested result is received by a recovering procedure,
* or the function discovers that the requested action is not yet executed, and return false
*
* \param buf the buffer to store the result
* \param size the total size of the buffer
* \param flag flag information about the action \sa ActionSummary
* \param seqno sequence number of the action, if it is special action with flag set,
* seqno needs to be set to ActionSummary::kMaxSeq
*
* \return true or false
* - true means buf is already set to the
* result by the recovering procedure, the action is complete, no further action is needed
* - false means this is the latest action that has not yet been executed, need to execute the action
*/
bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
  // a special action (non-zero flag) must carry the kMaxSeq sequence number
  if (flag != 0) {
    utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations");
  }
  // the action this node is asking for
  ActionSummary req(flag, seqno);
  for (;;) {
    // seed with our own request, then reduce it with ActionSummary::Reducer
    // through TryAllReduce to obtain the combined action
    ActionSummary act = req;
    if (!CheckAndRecover(TryAllReduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue;

    // ---- combined action carries check_ack ----
    if (act.check_ack()) {
      if (act.check_point()) {
        // if we also have check_point, do check point first
        utils::Assert(!act.diff_seq(),
                      "check ack & check pt cannot occur together with normal ops");
        // a node that asked for the checkpoint is free to go
        if (req.check_point()) return true;
      } else if (act.load_check()) {
        // only check_ack and load_check: run load_check; a node that asked
        // for it is done once the load went through without recovery
        if (CheckAndRecover(TryLoadCheckPoint(req.load_check())) &&
            req.load_check()) {
          return true;
        }
      } else if (req.check_ack()) {
        // neither check point nor load check: the check ack itself is executed
        return true;
      }
      // requested action not completed yet, go for another round
      continue;
    }

    // ---- no check_ack, but check_point is set ----
    if (act.check_point()) {
      if (!act.diff_seq()) {
        // no difference in seq no, means we are free to check point
        if (req.check_point()) return true;
        continue;
      }
      utils::Assert(act.min_seqno() != ActionSummary::kMaxSeq, "min seq bug");
      const bool is_requester = req.min_seqno() == act.min_seqno();
      if (CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), is_requester)) &&
          is_requester) {
        return true;
      }
      continue;
    }

    // ---- no check_ack, no check_point, but load_check is set ----
    if (act.load_check()) {
      // load check has higher priority than normal operations
      if (CheckAndRecover(TryLoadCheckPoint(req.load_check())) &&
          req.load_check()) {
        return true;
      }
      continue;
    }

    // ---- no special flags at all: only normal operations remain ----
    utils::Assert(act.min_seqno() != ActionSummary::kMaxSeq, "min seq bug");
    if (!act.diff_seq()) {
      // every request is the same: this is the most recent command that is
      // yet to be executed, so the caller has to run it
      return false;
    }
    const bool is_requester = req.min_seqno() == act.min_seqno();
    if (CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), is_requester)) &&
        is_requester) {
      return true;
    }
    // something is still incomplete, try next round
  }
  utils::Assert(false, "RecoverExec: should not reach here");
  return true;
}
} // namespace engine