680 lines
26 KiB
C++
680 lines
26 KiB
C++
/*!
|
|
* \file allreduce_robust.cc
|
|
* \brief Robust implementation of Allreduce
|
|
*
|
|
* \author Tianqi Chen, Ignacio Cano, Tianyi Zhou
|
|
*/
|
|
#define _CRT_SECURE_NO_WARNINGS
|
|
#define _CRT_SECURE_NO_DEPRECATE
|
|
#define NOMINMAX
|
|
#include <limits>
|
|
#include <utility>
|
|
#include "./io.h"
|
|
#include "./utils.h"
|
|
#include "./allreduce_robust.h"
|
|
|
|
namespace rabit {
|
|
namespace engine {
|
|
/*!
 * \brief construct the robust engine with its default recovery settings
 */
AllreduceRobust::AllreduceRobust() {
  // no operation has been issued yet
  seq_counter = 0;
  // with a round of 1 the modulo test in Allreduce/Broadcast never drops
  // a cached result
  result_buffer_round = 1;
}
|
|
/*! \brief shutdown the engine */
|
|
void AllreduceRobust::Shutdown(void) {
|
|
// need to sync the exec before we shutdown, do a pesudo check point
|
|
// execute checkpoint, note: when checkpoint existing, load will not happen
|
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
|
|
"check point must return true");
|
|
// reset result buffer
|
|
resbuf.Clear(); seq_counter = 0;
|
|
// execute check ack step, load happens here
|
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
|
"check ack must return true");
|
|
AllreduceBase::Shutdown();
|
|
}
|
|
/*!
|
|
* \brief set parameters to the engine
|
|
* \param name parameter name
|
|
* \param val parameter value
|
|
*/
|
|
void AllreduceRobust::SetParam(const char *name, const char *val) {
|
|
AllreduceBase::SetParam(name, val);
|
|
if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val);
|
|
if (!strcmp(name, "result_replicate")) {
|
|
result_buffer_round = std::max(world_size / atoi(val), 1);
|
|
}
|
|
}
|
|
/*!
|
|
* \brief perform in-place allreduce, on sendrecvbuf
|
|
* this function is NOT thread-safe
|
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
|
* \param type_nbytes the unit number of bytes the type have
|
|
* \param count number of elements to be reduced
|
|
* \param reducer reduce function
|
|
*/
|
|
void AllreduceRobust::Allreduce(void *sendrecvbuf_,
|
|
size_t type_nbytes,
|
|
size_t count,
|
|
ReduceFunction reducer) {
|
|
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
|
|
// now we are free to remove the last result, if any
|
|
if (resbuf.LastSeqNo() != -1 &&
|
|
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
|
resbuf.DropLast();
|
|
}
|
|
void *temp = resbuf.AllocTemp(type_nbytes, count);
|
|
while (true) {
|
|
if (recovered) {
|
|
std::memcpy(temp, sendrecvbuf_, type_nbytes * count); break;
|
|
} else {
|
|
std::memcpy(temp, sendrecvbuf_, type_nbytes * count);
|
|
if (CheckAndRecover(TryAllreduce(temp, type_nbytes, count, reducer))) {
|
|
std::memcpy(sendrecvbuf_, temp, type_nbytes * count); break;
|
|
} else {
|
|
recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
|
|
}
|
|
}
|
|
}
|
|
resbuf.PushTemp(seq_counter, type_nbytes, count);
|
|
seq_counter += 1;
|
|
}
|
|
/*!
|
|
* \brief broadcast data from root to all nodes
|
|
* \param sendrecvbuf_ buffer for both sending and recving data
|
|
* \param size the size of the data to be broadcasted
|
|
* \param root the root worker id to broadcast the data
|
|
*/
|
|
void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
|
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
|
// now we are free to remove the last result, if any
|
|
if (resbuf.LastSeqNo() != -1 &&
|
|
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
|
|
resbuf.DropLast();
|
|
}
|
|
void *temp = resbuf.AllocTemp(1, total_size);
|
|
while (true) {
|
|
if (recovered) {
|
|
std::memcpy(temp, sendrecvbuf_, total_size); break;
|
|
} else {
|
|
if (CheckAndRecover(TryBroadcast(sendrecvbuf_, total_size, root))) {
|
|
std::memcpy(temp, sendrecvbuf_, total_size); break;
|
|
} else {
|
|
recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
|
|
}
|
|
}
|
|
}
|
|
resbuf.PushTemp(seq_counter, 1, total_size);
|
|
seq_counter += 1;
|
|
}
|
|
/*!
|
|
* \brief load latest check point
|
|
* \param p_model pointer to the model
|
|
* \return the version number of check point loaded
|
|
* if returned version == 0, this means no model has been CheckPointed
|
|
* the p_model is not touched, user should do necessary initialization by themselves
|
|
* \sa CheckPoint, VersionNumber
|
|
*/
|
|
int AllreduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
|
|
// check if we succesfll
|
|
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
|
|
// reset result buffer
|
|
resbuf.Clear(); seq_counter = 0;
|
|
// load from buffer
|
|
utils::MemoryBufferStream fs(&checked_model);
|
|
fs.Read(&version_number, sizeof(version_number));
|
|
if (version_number == 0) return version_number;
|
|
p_model->Load(fs);
|
|
// run another phase of check ack, if recovered from data
|
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
|
"check ack must return true");
|
|
return version_number;
|
|
} else {
|
|
// reset result buffer
|
|
resbuf.Clear(); seq_counter = 0;
|
|
// nothing loaded, a fresh start, everyone init model
|
|
return false;
|
|
}
|
|
}
|
|
/*!
|
|
* \brief checkpoint the model, meaning we finished a stage of execution
|
|
* every time we call check point, there is a version number which will increase by one
|
|
*
|
|
* \param p_model pointer to the model
|
|
* \sa LoadCheckPoint, VersionNumber
|
|
*/
|
|
void AllreduceRobust::CheckPoint(const utils::ISerializable &model) {
|
|
// increase version number
|
|
version_number += 1;
|
|
// save model
|
|
checked_model.resize(0);
|
|
utils::MemoryBufferStream fs(&checked_model);
|
|
fs.Write(&version_number, sizeof(version_number));
|
|
model.Save(fs);
|
|
// execute checkpoint, note: when checkpoint existing, load will not happen
|
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kMaxSeq),
|
|
"check point must return true");
|
|
// reset result buffer
|
|
resbuf.Clear(); seq_counter = 0;
|
|
// execute check ack step, load happens here
|
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
|
|
"check ack must return true");
|
|
}
|
|
/*!
|
|
* \brief reset the all the existing links by sending Out-of-Band message marker
|
|
* after this function finishes, all the messages received and sent before in all live links are discarded,
|
|
* This allows us to get a fresh start after error has happened
|
|
*
|
|
* \return this function can return kSuccess or kSockError
|
|
* when kSockError is returned, it simply means there are bad sockets in the links,
|
|
* and some link recovery proceduer is needed
|
|
*/
|
|
AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
|
|
// number of links
|
|
const int nlink = static_cast<int>(links.size());
|
|
for (int i = 0; i < nlink; ++i) {
|
|
links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
|
|
links[i].ResetSize();
|
|
}
|
|
// read and discard data from all channels until pass mark
|
|
while (true) {
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (links[i].sock.BadSocket()) continue;
|
|
if (links[i].size_write == 0) {
|
|
char sig = kOOBReset;
|
|
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
|
|
// error will be filtered in next loop
|
|
if (len == sizeof(sig)) links[i].size_write = 1;
|
|
}
|
|
if (links[i].size_write == 1) {
|
|
char sig = kResetMark;
|
|
ssize_t len = links[i].sock.Send(&sig, sizeof(sig));
|
|
if (len == sizeof(sig)) links[i].size_write = 2;
|
|
}
|
|
}
|
|
utils::SelectHelper rsel;
|
|
bool finished = true;
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (links[i].size_write != 2 && !links[i].sock.BadSocket()) {
|
|
rsel.WatchWrite(links[i].sock); finished = false;
|
|
}
|
|
}
|
|
if (finished) break;
|
|
// wait to read from the channels to discard data
|
|
rsel.Select();
|
|
}
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (!links[i].sock.BadSocket()) {
|
|
utils::SelectHelper::WaitExcept(links[i].sock);
|
|
}
|
|
}
|
|
while (true) {
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (links[i].size_read == 0) {
|
|
int atmark = links[i].sock.AtMark();
|
|
if (atmark < 0) {
|
|
utils::Assert(links[i].sock.BadSocket(), "must already gone bad");
|
|
} else if (atmark > 0) {
|
|
links[i].size_read = 1;
|
|
} else {
|
|
// no at mark, read and discard data
|
|
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
|
|
if (links[i].sock.AtMark()) links[i].size_read = 1;
|
|
// zero length, remote closed the connection, close socket
|
|
if (len == 0) links[i].sock.Close();
|
|
}
|
|
}
|
|
}
|
|
utils::SelectHelper rsel;
|
|
bool finished = true;
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (links[i].size_read == 0 && !links[i].sock.BadSocket()) {
|
|
rsel.WatchRead(links[i].sock); finished = false;
|
|
}
|
|
}
|
|
if (finished) break;
|
|
rsel.Select();
|
|
}
|
|
|
|
// start synchronization, use blocking I/O to avoid select
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (!links[i].sock.BadSocket()) {
|
|
char oob_mark;
|
|
links[i].sock.SetNonBlock(false);
|
|
ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
|
|
if (len == 0) {
|
|
links[i].sock.Close(); continue;
|
|
} else if (len > 0) {
|
|
utils::Assert(oob_mark == kResetMark, "wrong oob msg");
|
|
utils::Assert(links[i].sock.AtMark() != 1, "should already read past mark");
|
|
} else {
|
|
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
|
}
|
|
// send out ack
|
|
char ack = kResetAck;
|
|
while (true) {
|
|
len = links[i].sock.Send(&ack, sizeof(ack));
|
|
if (len == sizeof(ack)) break;
|
|
if (len == -1) {
|
|
if (errno != EAGAIN && errno != EWOULDBLOCK) break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// wait all ack
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (!links[i].sock.BadSocket()) {
|
|
char ack;
|
|
ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
|
|
if (len == 0) {
|
|
links[i].sock.Close(); continue;
|
|
} else if (len > 0) {
|
|
utils::Assert(ack == kResetAck, "wrong Ack MSG");
|
|
} else {
|
|
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
|
}
|
|
// set back to nonblock mode
|
|
links[i].sock.SetNonBlock(true);
|
|
}
|
|
}
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (links[i].sock.BadSocket()) return kSockError;
|
|
}
|
|
return kSuccess;
|
|
}
|
|
/*!
|
|
* \brief try to reconnect the broken links
|
|
* \return this function can kSuccess or kSockError
|
|
*/
|
|
AllreduceRobust::ReturnType AllreduceRobust::TryReConnectLinks(void) {
|
|
utils::Error("TryReConnectLinks: not implemented");
|
|
return kSuccess;
|
|
}
|
|
/*!
|
|
* \brief if err_type indicates an error
|
|
* recover links according to the error type reported
|
|
* if there is no error, return true
|
|
* \param err_type the type of error happening in the system
|
|
* \return true if err_type is kSuccess, false otherwise
|
|
*/
|
|
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
|
|
if (err_type == kSuccess) return true;
|
|
while(err_type != kSuccess) {
|
|
switch(err_type) {
|
|
case kGetExcept: err_type = TryResetLinks(); break;
|
|
case kSockError: {
|
|
TryResetLinks();
|
|
err_type = TryReConnectLinks();
|
|
break;
|
|
}
|
|
default: utils::Assert(false, "RecoverLinks: cannot reach here");
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
/*!
 * \brief message passing function, used to decide the
 *        shortest distance to the possible source of data
 * \param node_value a pair of have_data and size
 *                   have_data: whether current node has the data
 *                   size: the size of the data, used when have_data is true
 * \param dist_in the shortest distance to any data source in each direction,
 *                std::numeric_limits<int>::max() marks a direction without source
 * \param out_index the edge index of the output link, this edge is ignored as input
 * \return the (distance, size) pair to send through the edge out_index;
 *         the distance stays std::numeric_limits<int>::max() when no source is known
 */
inline std::pair<int, size_t>
ShortestDist(const std::pair<bool, size_t> &node_value,
             const std::vector< std::pair<int, size_t> > &dist_in,
             size_t out_index) {
  // the node itself is a source: one hop away for the neighbor
  if (node_value.first) {
    return std::make_pair(1, node_value.second);
  }
  const int kNoSource = std::numeric_limits<int>::max();
  std::pair<int, size_t> best(kNoSource, 0);
  for (size_t k = 0; k < dist_in.size(); ++k) {
    // skip the output edge itself and directions that know no source
    if (k == out_index || dist_in[k].first == kNoSource) continue;
    // add one hop for passing through this node
    const int via = dist_in[k].first + 1;
    if (via < best.first) {
      best.first = via;
      best.second = dist_in[k].second;
    }
  }
  return best;
}
|
|
/*!
 * \brief message passing function, used to decide the
 *        data request from each edge, whether need to request data from certain edge
 * \param node_value a pair of request_data and best_link
 *                   request_data: whether current node needs to request data
 *                   best_link: the best edge index to fetch the data,
 *                   can be -1, which means current node contains data
 * \param req_in the data request from incoming edges
 * \param out_index the edge index of the output link
 * \return the request to the output edge: 1 to request data, 0 otherwise
 */
inline char DataRequest(const std::pair<bool, int> &node_value,
                        const std::vector<char> &req_in,
                        size_t out_index) {
  // only the best link is ever asked for data
  if (static_cast<int>(out_index) != node_value.second) return 0;
  // the node itself needs the data
  if (node_value.first) return 1;
  // otherwise forward a request coming from any other edge
  for (size_t k = 0; k < req_in.size(); ++k) {
    if (k != out_index && req_in[k] != 0) return 1;
  }
  return 0;
}
|
|
/*!
|
|
* \brief try to decide the recovery message passing request
|
|
* \param role the current role of the node
|
|
* \param p_size used to store the size of the message, for node in state kHaveData,
|
|
* this size must be set correctly before calling the function
|
|
* for others, this surves as output parameter
|
|
*
|
|
* \param p_recvlink used to store the link current node should recv data from, if necessary
|
|
* this can be -1, which means current node have the data
|
|
* \param p_req_in used to store the resulting vector, indicating which link we should send the data to
|
|
*
|
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
|
* \sa ReturnType
|
|
*/
|
|
AllreduceRobust::ReturnType
|
|
AllreduceRobust::TryDecideRouting(AllreduceRobust::RecoverType role,
|
|
size_t *p_size,
|
|
int *p_recvlink,
|
|
std::vector<bool> *p_req_in) {
|
|
int best_link = -2;
|
|
{// get the shortest distance to the request point
|
|
std::vector< std::pair<int,size_t> > dist_in, dist_out;
|
|
ReturnType succ = MsgPassing(std::make_pair(role == kHaveData, *p_size),
|
|
&dist_in, &dist_out, ShortestDist);
|
|
if (succ != kSuccess) return succ;
|
|
if (role != kHaveData) {
|
|
for (size_t i = 0; i < dist_in.size(); ++i) {
|
|
if (dist_in[i].first != std::numeric_limits<int>::max()) {
|
|
utils::Check(best_link == -2 || *p_size == dist_in[i].second,
|
|
"[%d] Allreduce size inconsistent, distin=%lu, size=%lu, reporting=%lu\n",
|
|
rank, dist_in[i].first, *p_size, dist_in[i].second);
|
|
if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) {
|
|
best_link = static_cast<int>(i);
|
|
*p_size = dist_in[i].second;
|
|
}
|
|
}
|
|
}
|
|
utils::Check(best_link != -2, "Too many nodes went down and we cannot recover..");
|
|
} else {
|
|
best_link = -1;
|
|
}
|
|
}
|
|
// get the node request
|
|
std::vector<char> req_in, req_out;
|
|
ReturnType succ = MsgPassing(std::make_pair(role == kRequestData, best_link),
|
|
&req_in, &req_out, DataRequest);
|
|
if (succ != kSuccess) return succ;
|
|
// set p_req_in
|
|
p_req_in->resize(req_in.size());
|
|
for (size_t i = 0; i < req_in.size(); ++i) {
|
|
// set p_req_in
|
|
(*p_req_in)[i] = (req_in[i] != 0);
|
|
if (req_out[i] != 0) {
|
|
utils::Assert(req_in[i] == 0, "cannot get and receive request");
|
|
utils::Assert(static_cast<int>(i) == best_link, "request result inconsistent");
|
|
}
|
|
}
|
|
*p_recvlink = best_link;
|
|
return kSuccess;
|
|
}
|
|
/*!
|
|
* \brief try to finish the data recovery request,
|
|
* this function is used together with TryDecideRouting
|
|
* \param role the current role of the node
|
|
* \param sendrecvbuf_ the buffer to store the data to be sent/recived
|
|
* - if the role is kHaveData, this stores the data to be sent
|
|
* - if the role is kRequestData, this is the buffer to store the result
|
|
* - if the role is kPassData, this will not be used, and can be NULL
|
|
* \param size the size of the data, obtained from TryDecideRouting
|
|
* \param recv_link the link index to receive data, if necessary, obtained from TryDecideRouting
|
|
* \param req_in the request of each link to send data, obtained from TryDecideRouting
|
|
*
|
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
|
* \sa ReturnType, TryDecideRouting
|
|
*/
|
|
AllreduceRobust::ReturnType
|
|
AllreduceRobust::TryRecoverData(RecoverType role,
|
|
void *sendrecvbuf_,
|
|
size_t size,
|
|
int recv_link,
|
|
const std::vector<bool> &req_in) {
|
|
// no need to run recovery for zero size message
|
|
if (links.size() == 0 || size == 0) return kSuccess;
|
|
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
|
|
const int nlink = static_cast<int>(links.size());
|
|
{
|
|
bool req_data = role == kRequestData;
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (req_in[i]) {
|
|
utils::Assert(i != recv_link, "TryDecideRouting");
|
|
req_data = true;
|
|
}
|
|
}
|
|
// do not need to provide data or receive data, directly exit
|
|
if (!req_data) return kSuccess;
|
|
}
|
|
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
|
|
for (int i = 0; i < nlink; ++i) {
|
|
links[i].ResetSize();
|
|
}
|
|
while (true) {
|
|
bool finished = true;
|
|
utils::SelectHelper selecter;
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (i == recv_link && links[i].size_read != size) {
|
|
selecter.WatchRead(links[i].sock);
|
|
finished = false;
|
|
}
|
|
if (req_in[i] && links[i].size_write != size) {
|
|
if (role == kHaveData ||
|
|
(role == kPassData && links[recv_link].size_read != links[i].size_write)) {
|
|
selecter.WatchWrite(links[i].sock);
|
|
}
|
|
finished = false;
|
|
}
|
|
selecter.WatchException(links[i].sock);
|
|
}
|
|
if (finished) break;
|
|
selecter.Select();
|
|
if (role == kRequestData) {
|
|
const int pid = recv_link;
|
|
if (selecter.CheckRead(links[pid].sock)) {
|
|
if(!links[pid].ReadToArray(sendrecvbuf_, size)) return kSockError;
|
|
}
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (req_in[i] && links[i].size_write != links[pid].size_read &&
|
|
selecter.CheckWrite(links[i].sock)) {
|
|
if(!links[i].WriteFromArray(sendrecvbuf_, links[pid].size_read)) return kSockError;
|
|
}
|
|
}
|
|
}
|
|
if (role == kHaveData) {
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (req_in[i] && selecter.CheckWrite(links[i].sock)) {
|
|
if(!links[i].WriteFromArray(sendrecvbuf_, size)) return kSockError;
|
|
}
|
|
}
|
|
}
|
|
if (role == kPassData) {
|
|
const int pid = recv_link;
|
|
const size_t buffer_size = links[pid].buffer_size;
|
|
if (selecter.CheckRead(links[pid].sock)) {
|
|
size_t min_write = size;
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
|
|
}
|
|
utils::Assert(min_write <= links[pid].size_read, "boundary check");
|
|
if (!links[pid].ReadToRingBuffer(min_write)) return kSockError;
|
|
}
|
|
for (int i = 0; i < nlink; ++i) {
|
|
if (req_in[i] && selecter.CheckWrite(links[i].sock) && links[pid].size_read != links[i].size_write) {
|
|
size_t start = links[i].size_write % buffer_size;
|
|
// send out data from ring buffer
|
|
size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write);
|
|
ssize_t len = links[i].sock.Send(links[pid].buffer_head + start, nwrite);
|
|
if (len != -1) {
|
|
links[i].size_write += len;
|
|
} else {
|
|
if (errno != EAGAIN && errno != EWOULDBLOCK) return kSockError;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return kSuccess;
|
|
}
|
|
/*!
|
|
* \brief try to load check point
|
|
*
|
|
* This is a collaborative function called by all nodes
|
|
* only the nodes with requester set to true really needs to load the check point
|
|
* other nodes acts as collaborative roles to complete this request
|
|
*
|
|
* \param requester whether current node is the requester
|
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
|
* \sa ReturnType
|
|
*/
|
|
AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
|
RecoverType role = requester ? kRequestData : kHaveData;
|
|
size_t size = this->checked_model.length();
|
|
int recv_link;
|
|
std::vector<bool> req_in;
|
|
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
|
|
if (succ != kSuccess) return succ;
|
|
if (role == kRequestData) {
|
|
checked_model.resize(size);
|
|
}
|
|
if (size == 0) return kSuccess;
|
|
return TryRecoverData(role, &checked_model[0], size, recv_link, req_in);
|
|
}
|
|
/*!
|
|
* \brief try to get the result of operation specified by seqno
|
|
*
|
|
* This is a collaborative function called by all nodes
|
|
* only the nodes with requester set to true really needs to get the result
|
|
* other nodes acts as collaborative roles to complete this request
|
|
*
|
|
* \param buf the buffer to store the result, this parameter is only used when current node is requester
|
|
* \param size the total size of the buffer, this parameter is only used when current node is requester
|
|
* \param seqno sequence number of the operation, this is unique index of a operation in current iteration
|
|
* \param requester whether current node is the requester
|
|
* \return this function can return kSuccess/kSockError/kGetExcept, see ReturnType for details
|
|
* \sa ReturnType
|
|
*/
|
|
AllreduceRobust::ReturnType
|
|
AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool requester) { RecoverType role;
|
|
if (!requester) {
|
|
sendrecvbuf = resbuf.Query(seqno, &size);
|
|
role = sendrecvbuf != NULL ? kHaveData : kPassData;
|
|
} else {
|
|
role = kRequestData;
|
|
}
|
|
int recv_link;
|
|
std::vector<bool> req_in;
|
|
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
|
|
if (succ != kSuccess) return succ;
|
|
utils::Check(size != 0, "zero size check point is not allowed");
|
|
return TryRecoverData(role, sendrecvbuf, size, recv_link, req_in);
|
|
}
|
|
/*!
|
|
* \brief try to run recover execution for a request action described by flag and seqno,
|
|
* the function will keep blocking to run possible recovery operations before the specified action,
|
|
* until the requested result is received by a recovering procedure,
|
|
* or the function discovers that the requested action is not yet executed, and return false
|
|
*
|
|
* \param buf the buffer to store the result
|
|
* \param size the total size of the buffer
|
|
* \param flag flag information about the action \sa ActionSummary
|
|
* \param seqno sequence number of the action, if it is special action with flag set,
|
|
* seqno needs to be set to ActionSummary::kMaxSeq
|
|
*
|
|
* \return if this function can return true or false
|
|
* - true means buf already set to the
|
|
* result by recovering procedure, the action is complete, no further action is needed
|
|
* - false means this is the lastest action that has not yet been executed, need to execute the action
|
|
*/
|
|
bool AllreduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
|
|
if (flag != 0) {
|
|
utils::Assert(seqno == ActionSummary::kMaxSeq, "must only set seqno for normal operations");
|
|
}
|
|
// request
|
|
ActionSummary req(flag, seqno);
|
|
while (true) {
|
|
// action
|
|
ActionSummary act = req;
|
|
// get the reduced action
|
|
if (!CheckAndRecover(TryAllreduce(&act, sizeof(act), 1, ActionSummary::Reducer))) continue;
|
|
if (act.check_ack()) {
|
|
if (act.check_point()) {
|
|
// if we also have check_point, do check point first
|
|
utils::Assert(!act.diff_seq(),
|
|
"check ack & check pt cannot occur together with normal ops");
|
|
// if we requested checkpoint, we are free to go
|
|
if (req.check_point()) return true;
|
|
} else if (act.load_check()) {
|
|
// if there is only check_ack and load_check, do load_check
|
|
if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
|
|
// if requested load check, then misson complete
|
|
if (req.load_check()) return true;
|
|
} else {
|
|
// there is no check point and no load check, execute check ack
|
|
if (req.check_ack()) return true;
|
|
}
|
|
// if execute to this point
|
|
// this means the action requested has not been completed
|
|
// try next round
|
|
} else {
|
|
if (act.check_point()) {
|
|
if (act.diff_seq()) {
|
|
utils::Assert(act.min_seqno() != ActionSummary::kMaxSeq, "min seq bug");
|
|
bool requester = req.min_seqno() == act.min_seqno();
|
|
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
|
|
if (requester) return true;
|
|
} else {
|
|
// no difference in seq no, means we are free to check point
|
|
if (req.check_point()) return true;
|
|
}
|
|
} else {
|
|
// no check point
|
|
if (act.load_check()) {
|
|
// all the nodes called load_check, this is an incomplete action
|
|
if (!act.diff_seq()) return false;
|
|
// load check have higher priority, do load_check
|
|
if (!CheckAndRecover(TryLoadCheckPoint(req.load_check()))) continue;
|
|
// if requested load check, then misson complete
|
|
if (req.load_check()) return true;
|
|
} else {
|
|
// no special flags, no checkpoint, check ack, load_check
|
|
utils::Assert(act.min_seqno() != ActionSummary::kMaxSeq, "min seq bug");
|
|
if (act.diff_seq()) {
|
|
bool requester = req.min_seqno() == act.min_seqno();
|
|
if (!CheckAndRecover(TryGetResult(buf, size, act.min_seqno(), requester))) continue;
|
|
if (requester) return true;
|
|
} else {
|
|
// all the request is same, this is most recent command that is yet to be executed
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
// something is still incomplete try next round
|
|
}
|
|
}
|
|
utils::Assert(false, "RecoverExec: should not reach here");
|
|
return true;
|
|
}
|
|
} // namespace engine
|
|
} // namespace rabit
|
|
|