nice fix, start check local check

This commit is contained in:
tqchen
2014-12-18 18:39:24 -08:00
parent 3f22596e3c
commit dbd05a65b5
2 changed files with 29 additions and 29 deletions

View File

@@ -151,18 +151,21 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
resbuf.Clear(); seq_counter = 0;
// load from buffer
utils::MemoryBufferStream fs(&global_checkpoint);
fs.Read(&version_number, sizeof(version_number));
if (version_number == 0) return version_number;
global_model->Load(fs);
if (global_checkpoint.length() == 0) {
version_number = 0;
} else {
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, "read in version number");
global_model->Load(fs);
}
// run another phase of check ack, if recovered from data
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"check ack must return true");
return version_number;
} else {
// reset result buffer
resbuf.Clear(); seq_counter = 0;
resbuf.Clear(); seq_counter = 0; version_number = 0;
// nothing loaded, a fresh start, everyone init model
return false;
return version_number;
}
}
/*!
@@ -185,8 +188,8 @@ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
const utils::ISerializable *local_model) {
if (num_local_replica == 0) {
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
}
if (num_local_replica != 0) {
}
if (num_local_replica != 0) {
while (true) {
if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break;
// save model to new version place
@@ -516,7 +519,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
int recv_link,
const std::vector<bool> &req_in) {
RefLinkVector &links = tree_links;
// no need to run recovery for zero size message
// no need to run recovery for zero size messages
if (links.size() == 0 || size == 0) return kSuccess;
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
const int nlink = static_cast<int>(links.size());
@@ -542,7 +545,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (i == recv_link && links[i].size_read != size) {
selecter.WatchRead(links[i].sock);
finished = false;
}
}
if (req_in[i] && links[i].size_write != size) {
if (role == kHaveData ||
(role == kPassData && links[recv_link].size_read != links[i].size_write)) {
@@ -620,16 +623,16 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
// check in local data
RecoverType role = requester ? kRequestData : kHaveData;
ReturnType succ;
if (false) {
if (num_local_replica != 0) {
if (requester) {
// clear existing history, if any, before load
local_rptr[local_chkpt_version].clear();
local_chkpt[local_chkpt_version].clear();
}
// recover local checkpoint
//succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
//m&local_chkpt[local_chkpt_version]);
//if (succ != kSuccess) return succ;
succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
&local_chkpt[local_chkpt_version]);
if (succ != kSuccess) return succ;
int nlocal = std::max(static_cast<int>(local_rptr[local_chkpt_version].size()) - 1, 0);
// check if everyone is OK
unsigned state = 0;
@@ -818,8 +821,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
}
const int n = num_local_replica;
utils::LogPrintf("[%d] backward!!\n", rabit::GetRank());
if(false){// backward passing, passing state in backward direction of the ring
{// backward passing, passing state in backward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_back(n + 1);
@@ -872,8 +874,6 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
}
}
utils::LogPrintf("[%d] FORward!!\n", rabit::GetRank());
{// forward passing, passing state in forward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica");
@@ -937,7 +937,6 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
}
}
utils::LogPrintf("[%d] Finished!!\n", rabit::GetRank());
return kSuccess;
}
/*!