check in the ring passing

This commit is contained in:
tqchen 2014-12-06 23:00:10 -08:00
parent 58f80c5675
commit 328cf187ba

View File

@ -740,23 +740,121 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
}
const int n = num_local_replica;
// message send to previous link
{
int msg_forward[2];
int nlocal = static_cast<int>(rptr.size() - 1);
msg_forward[0] = nlocal;
utils::Assert(msg_forward[0] <= n, "invalid local replica");
{// backward passing, passing state in backward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_back(n + 1);
msg_back[0] = nlocal;
// backward passing one hop the request
ReturnType succ = RingPassing(msg_forward,
ReturnType succ;
succ = RingPassing(BeginPtr(msg_back),
1 * sizeof(int), (n+1) * sizeof(int),
0 * sizeof(int), n * sizeof(int),
ring_next, ring_prev);
if (succ != kSuccess) return succ;
int msg_forward[2];
msg_forward[0] = nlocal;
succ = RingPassing(msg_forward,
1 * sizeof(int), 2 * sizeof(int),
0 * sizeof(int), 1 * sizeof(int),
ring_prev, ring_next);
if (succ != kSuccess) return succ;
// check how much current node can help with the request
// if (nlocal > ) {
//}
// calculate the number of things we can read from next link
int nread_end = nlocal;
for (int i = 1; i <= n; ++i) {
nread_end = std::max(nread_end, msg_back[i] - i);
}
// gives the size of forward
int nwrite_start = std::min(msg_forward[1] + 1, nread_end);
// get the size of each segments
std::vector<size_t> sizes(nread_end);
for (int i = 0; i < nlocal; ++i) {
sizes[i] = rptr[i + 1] - rptr[i];
}
// pass size through the link
succ = RingPassing(BeginPtr(sizes),
nlocal * sizeof(size_t),
nread_end * sizeof(size_t),
nwrite_start * sizeof(size_t),
nread_end * sizeof(size_t),
ring_next, ring_prev);
if (succ != kSuccess) return succ;
// update rptr
rptr.resize(nread_end + 1);
for (int i = nlocal; i < nread_end; ++i) {
rptr[i + 1] = rptr[i] + sizes[i];
}
chkpt.resize(rptr.back());
// pass data through the link
succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end],
rptr[nwrite_start], rptr[nread_end],
ring_next, ring_prev);
if (succ != kSuccess) {
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
}
}
{// forward passing, passing state in forward direction of the ring
const int nlocal = static_cast<int>(rptr.size() - 1);
utils::Assert(nlocal <= n + 1, "invalid local replica");
std::vector<int> msg_forward(n + 1);
msg_forward[0] = nlocal;
// backward passing one hop the request
ReturnType succ;
succ = RingPassing(BeginPtr(msg_forward),
1 * sizeof(int), (n+1) * sizeof(int),
0 * sizeof(int), n * sizeof(int),
ring_prev, ring_next);
if (succ != kSuccess) return succ;
int msg_back[2];
msg_back[0] = nlocal;
succ = RingPassing(msg_back,
1 * sizeof(int), 2 * sizeof(int),
0 * sizeof(int), 1 * sizeof(int),
ring_next, ring_prev);
if (succ != kSuccess) return succ;
// calculate the number of things we can read from next link
int nread_end = nlocal, nwrite_end = 1;
// have to have itself in order to get other data from prev link
if (nlocal != 0) {
for (int i = 1; i <= n; ++i) {
if (msg_forward[i] == 0) break;
nread_end = std::max(nread_end, i + 1);
nwrite_end = i + 1;
}
if (nwrite_end > n) nwrite_end = n;
} else {
nread_end = 0; nwrite_end = 0;
}
// gives the size of forward
int nwrite_start = std::min(msg_back[1] - 1, nwrite_end);
// next node miss the state of itself, cannot recover
if (nwrite_start < 0) nwrite_start = nwrite_end = 0;
// get the size of each segments
std::vector<size_t> sizes(nread_end);
for (int i = 0; i < nlocal; ++i) {
sizes[i] = rptr[i + 1] - rptr[i];
}
// pass size through the link, check consistency
succ = RingPassing(BeginPtr(sizes),
nlocal * sizeof(size_t),
nread_end * sizeof(size_t),
nwrite_start * sizeof(size_t),
nwrite_end * sizeof(size_t),
ring_prev, ring_next);
if (succ != kSuccess) return succ;
// update rptr
rptr.resize(nread_end + 1);
for (int i = nlocal; i < nread_end; ++i) {
rptr[i + 1] = rptr[i] + sizes[i];
}
chkpt.resize(rptr.back());
// pass data through the link
succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end],
rptr[nwrite_start], rptr[nwrite_end],
ring_prev, ring_next);
if (succ != kSuccess) {
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
}
}
return kSuccess;
}