nice fix, start check local check
This commit is contained in:
parent
3f22596e3c
commit
dbd05a65b5
@ -151,18 +151,21 @@ int AllreduceRobust::LoadCheckPoint(utils::ISerializable *global_model,
|
|||||||
resbuf.Clear(); seq_counter = 0;
|
resbuf.Clear(); seq_counter = 0;
|
||||||
// load from buffer
|
// load from buffer
|
||||||
utils::MemoryBufferStream fs(&global_checkpoint);
|
utils::MemoryBufferStream fs(&global_checkpoint);
|
||||||
fs.Read(&version_number, sizeof(version_number));
|
if (global_checkpoint.length() == 0) {
|
||||||
if (version_number == 0) return version_number;
|
version_number = 0;
|
||||||
global_model->Load(fs);
|
} else {
|
||||||
|
utils::Assert(fs.Read(&version_number, sizeof(version_number)) != 0, "read in version number");
|
||||||
|
global_model->Load(fs);
|
||||||
|
}
|
||||||
// run another phase of check ack, if recovered from data
|
// run another phase of check ack, if recovered from data
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
||||||
"check ack must return true");
|
"check ack must return true");
|
||||||
return version_number;
|
return version_number;
|
||||||
} else {
|
} else {
|
||||||
// reset result buffer
|
// reset result buffer
|
||||||
resbuf.Clear(); seq_counter = 0;
|
resbuf.Clear(); seq_counter = 0; version_number = 0;
|
||||||
// nothing loaded, a fresh start, everyone init model
|
// nothing loaded, a fresh start, everyone init model
|
||||||
return false;
|
return version_number;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -185,8 +188,8 @@ void AllreduceRobust::CheckPoint(const utils::ISerializable *global_model,
|
|||||||
const utils::ISerializable *local_model) {
|
const utils::ISerializable *local_model) {
|
||||||
if (num_local_replica == 0) {
|
if (num_local_replica == 0) {
|
||||||
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
utils::Check(local_model == NULL, "need to set num_local_replica larger than 1 to checkpoint local_model");
|
||||||
}
|
}
|
||||||
if (num_local_replica != 0) {
|
if (num_local_replica != 0) {
|
||||||
while (true) {
|
while (true) {
|
||||||
if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break;
|
if (RecoverExec(NULL, 0, 0, ActionSummary::kLocalCheckPoint)) break;
|
||||||
// save model to new version place
|
// save model to new version place
|
||||||
@ -516,7 +519,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
|||||||
int recv_link,
|
int recv_link,
|
||||||
const std::vector<bool> &req_in) {
|
const std::vector<bool> &req_in) {
|
||||||
RefLinkVector &links = tree_links;
|
RefLinkVector &links = tree_links;
|
||||||
// no need to run recovery for zero size message
|
// no need to run recovery for zero size messages
|
||||||
if (links.size() == 0 || size == 0) return kSuccess;
|
if (links.size() == 0 || size == 0) return kSuccess;
|
||||||
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
|
utils::Assert(req_in.size() == links.size(), "TryRecoverData");
|
||||||
const int nlink = static_cast<int>(links.size());
|
const int nlink = static_cast<int>(links.size());
|
||||||
@ -542,7 +545,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
|||||||
if (i == recv_link && links[i].size_read != size) {
|
if (i == recv_link && links[i].size_read != size) {
|
||||||
selecter.WatchRead(links[i].sock);
|
selecter.WatchRead(links[i].sock);
|
||||||
finished = false;
|
finished = false;
|
||||||
}
|
}
|
||||||
if (req_in[i] && links[i].size_write != size) {
|
if (req_in[i] && links[i].size_write != size) {
|
||||||
if (role == kHaveData ||
|
if (role == kHaveData ||
|
||||||
(role == kPassData && links[recv_link].size_read != links[i].size_write)) {
|
(role == kPassData && links[recv_link].size_read != links[i].size_write)) {
|
||||||
@ -620,16 +623,16 @@ AllreduceRobust::ReturnType AllreduceRobust::TryLoadCheckPoint(bool requester) {
|
|||||||
// check in local data
|
// check in local data
|
||||||
RecoverType role = requester ? kRequestData : kHaveData;
|
RecoverType role = requester ? kRequestData : kHaveData;
|
||||||
ReturnType succ;
|
ReturnType succ;
|
||||||
if (false) {
|
if (num_local_replica != 0) {
|
||||||
if (requester) {
|
if (requester) {
|
||||||
// clear existing history, if any, before load
|
// clear existing history, if any, before load
|
||||||
local_rptr[local_chkpt_version].clear();
|
local_rptr[local_chkpt_version].clear();
|
||||||
local_chkpt[local_chkpt_version].clear();
|
local_chkpt[local_chkpt_version].clear();
|
||||||
}
|
}
|
||||||
// recover local checkpoint
|
// recover local checkpoint
|
||||||
//succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
|
succ = TryRecoverLocalState(&local_rptr[local_chkpt_version],
|
||||||
//m&local_chkpt[local_chkpt_version]);
|
&local_chkpt[local_chkpt_version]);
|
||||||
//if (succ != kSuccess) return succ;
|
if (succ != kSuccess) return succ;
|
||||||
int nlocal = std::max(static_cast<int>(local_rptr[local_chkpt_version].size()) - 1, 0);
|
int nlocal = std::max(static_cast<int>(local_rptr[local_chkpt_version].size()) - 1, 0);
|
||||||
// check if everyone is OK
|
// check if everyone is OK
|
||||||
unsigned state = 0;
|
unsigned state = 0;
|
||||||
@ -818,8 +821,7 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
|||||||
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
|
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
|
||||||
}
|
}
|
||||||
const int n = num_local_replica;
|
const int n = num_local_replica;
|
||||||
utils::LogPrintf("[%d] backward!!\n", rabit::GetRank());
|
{// backward passing, passing state in backward direction of the ring
|
||||||
if(false){// backward passing, passing state in backward direction of the ring
|
|
||||||
const int nlocal = static_cast<int>(rptr.size() - 1);
|
const int nlocal = static_cast<int>(rptr.size() - 1);
|
||||||
utils::Assert(nlocal <= n + 1, "invalid local replica");
|
utils::Assert(nlocal <= n + 1, "invalid local replica");
|
||||||
std::vector<int> msg_back(n + 1);
|
std::vector<int> msg_back(n + 1);
|
||||||
@ -872,8 +874,6 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
|||||||
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
|
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
utils::LogPrintf("[%d] FORward!!\n", rabit::GetRank());
|
|
||||||
{// forward passing, passing state in forward direction of the ring
|
{// forward passing, passing state in forward direction of the ring
|
||||||
const int nlocal = static_cast<int>(rptr.size() - 1);
|
const int nlocal = static_cast<int>(rptr.size() - 1);
|
||||||
utils::Assert(nlocal <= n + 1, "invalid local replica");
|
utils::Assert(nlocal <= n + 1, "invalid local replica");
|
||||||
@ -937,7 +937,6 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
|||||||
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
|
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
utils::LogPrintf("[%d] Finished!!\n", rabit::GetRank());
|
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -26,6 +26,7 @@ class Model : public rabit::utils::ISerializable {
|
|||||||
fo.Write(data);
|
fo.Write(data);
|
||||||
}
|
}
|
||||||
virtual void InitModel(size_t n) {
|
virtual void InitModel(size_t n) {
|
||||||
|
data.clear();
|
||||||
data.resize(n, 1.0f);
|
data.resize(n, 1.0f);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -40,15 +41,15 @@ inline void TestMax(test::Mock &mock, Model *model, int ntrial, int iter) {
|
|||||||
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
||||||
}
|
}
|
||||||
mock.Allreduce<op::Max>(&ndata[0], ndata.size());
|
mock.Allreduce<op::Max>(&ndata[0], ndata.size());
|
||||||
if (ntrial == iter && rank == 3) {
|
if (ntrial == 0 && rank == 3) {
|
||||||
// exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rmax = (i * 1) % z + model->data[i];
|
float rmax = (i * 1) % z + model->data[i];
|
||||||
for (int r = 0; r < nproc; ++r) {
|
for (int r = 0; r < nproc; ++r) {
|
||||||
rmax = std::max(rmax, (float)((i * (r+1)) % z) + model->data[i]);
|
rmax = std::max(rmax, (float)((i * (r+1)) % z) + model->data[i]);
|
||||||
}
|
}
|
||||||
utils::Check(rmax == ndata[i], "[%d] TestMax check failure", rank);
|
utils::Check(rmax == ndata[i], "[%d] TestMax check failurem i=%lu, rmax=%f, ndata=%f", rank, i, rmax, ndata[i]);
|
||||||
}
|
}
|
||||||
model->data = ndata;
|
model->data = ndata;
|
||||||
}
|
}
|
||||||
@ -62,12 +63,12 @@ inline void TestSum(test::Mock &mock, Model *model, int ntrial, int iter) {
|
|||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
ndata[i] = (i * (rank+1)) % z + model->data[i];
|
||||||
}
|
}
|
||||||
mock.Allreduce<op::Sum>(&ndata[0], ndata.size());
|
if (iter == 0 && ntrial==0 && rank == 0) {
|
||||||
|
|
||||||
if (ntrial == iter && rank == 0) {
|
|
||||||
throw MockException();
|
throw MockException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mock.Allreduce<op::Sum>(&ndata[0], ndata.size());
|
||||||
|
|
||||||
for (size_t i = 0; i < ndata.size(); ++i) {
|
for (size_t i = 0; i < ndata.size(); ++i) {
|
||||||
float rsum = model->data[i] * nproc;
|
float rsum = model->data[i] * nproc;
|
||||||
for (int r = 0; r < nproc; ++r) {
|
for (int r = 0; r < nproc; ++r) {
|
||||||
@ -125,11 +126,11 @@ int main(int argc, char *argv[]) {
|
|||||||
for (int r = iter; r < 3; ++r) {
|
for (int r = iter; r < 3; ++r) {
|
||||||
TestMax(mock, &model, ntrial, r);
|
TestMax(mock, &model, ntrial, r);
|
||||||
utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
utils::LogPrintf("[%d] !!!TestMax pass, iter=%d\n", rank, r);
|
||||||
int step = std::max(nproc / 3, 1);
|
//int step = std::max(nproc / 3, 1);
|
||||||
for (int i = 0; i < nproc; i += step) {
|
//for (int i = 0; i < nproc; i += step) {
|
||||||
TestBcast(mock, n, i, ntrial);
|
//TestBcast(mock, n, i, ntrial);
|
||||||
}
|
//}
|
||||||
utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
//utils::LogPrintf("[%d] !!!TestBcast pass, iter=%d\n", rank, r);
|
||||||
TestSum(mock, &model, ntrial, r);
|
TestSum(mock, &model, ntrial, r);
|
||||||
utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
utils::LogPrintf("[%d] !!!TestSum pass, iter=%d\n", rank, r);
|
||||||
rabit::CheckPoint(&model);
|
rabit::CheckPoint(&model);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user