find one bug, continue to next one

This commit is contained in:
tqchen
2014-12-01 19:34:27 -08:00
parent 2cde04867f
commit 993ff8bb91
3 changed files with 17 additions and 12 deletions

View File

@@ -16,7 +16,7 @@
namespace rabit {
namespace engine {
/*!
 * \brief constructor: sets result_buffer_round and resets seq_counter to zero
 */
AllReduceRobust::AllReduceRobust(void) {
// NOTE(review): result_buffer_round is assigned twice in a row, so only the
// second value (2) takes effect. This looks like the old and new lines of a
// diff shown together -- confirm which value is intended.
result_buffer_round = 1;
result_buffer_round = 2;
// sequence counter starts from zero
seq_counter = 0;
}
/*!
@@ -32,6 +32,7 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_,
size_t count,
ReduceFunction reducer) {
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
utils::LogPrintf("[%d] AllReduce recovered=%d\n", rank, recovered);
// now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 &&
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
@@ -90,19 +91,21 @@ void AllReduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
bool AllReduceRobust::LoadCheckPoint(utils::ISerializable *p_model) {
// check if we succeeded
if (RecoverExec(NULL, 0, ActionSummary::kLoadCheck, ActionSummary::kMaxSeq)) {
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// if loaded model is empty, this simply means we did not call checkpoint yet
// ask caller to reinit model
if (checked_model.length() == 0) return false;
// load from buffer
utils::MemoryBufferStream fs(&checked_model);
p_model->Load(fs);
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// run another phase of check ack, if recovered from data
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kMaxSeq),
"check ack must return true");
return true;
} else {
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// nothing loaded, a fresh start, everyone init model
return false;
}
@@ -362,7 +365,7 @@ AllReduceRobust::TryDecideRouting(AllReduceRobust::RecoverType role,
for (size_t i = 0; i < dist_in.size(); ++i) {
if (dist_in[i].first != std::numeric_limits<int>::max()) {
utils::Check(best_link == -2 || *p_size == dist_in[i].second,
"AllReduce size inconsistent");
"AllReduce size inconsistent, size=%lu, reporting=%lu", *p_size, dist_in[i].second);
if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) {
best_link = static_cast<int>(i);
*p_size = dist_in[i].second;
@@ -413,6 +416,7 @@ AllReduceRobust::TryRecoverData(RecoverType role,
size_t size,
int recv_link,
const std::vector<bool> &req_in) {
utils::LogPrintf("[%d] recv_link=%d\n", rank, recv_link);
// no need to run recovery for zero size message
if (links.size() == 0 || size == 0) return kSuccess;
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
@@ -519,7 +523,7 @@ AllReduceRobust::ReturnType AllReduceRobust::TryLoadCheckPoint(bool requester) {
if (role == kRequestData) {
checked_model.resize(size);
}
utils::Check(size != 0, "zero size check point is not allowed");
if (size == 0) return kSuccess;
return TryRecoverData(role, &checked_model[0], size, recv_link, req_in);
}
/*!
@@ -574,6 +578,7 @@ bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
}
// request
ActionSummary req(flag, seqno);
utils::LogPrintf("[%d] propose flag=%d, seq=%d\n", rank, flag, seqno);
while (true) {
// action
ActionSummary act = req;