check in the ring passing
This commit is contained in:
parent
58f80c5675
commit
328cf187ba
@ -740,23 +740,121 @@ AllreduceRobust::TryRecoverLocalState(std::vector<size_t> *p_local_rptr,
|
||||
utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent");
|
||||
}
|
||||
const int n = num_local_replica;
|
||||
// message send to previous link
|
||||
{
|
||||
int msg_forward[2];
|
||||
int nlocal = static_cast<int>(rptr.size() - 1);
|
||||
msg_forward[0] = nlocal;
|
||||
utils::Assert(msg_forward[0] <= n, "invalid local replica");
|
||||
{// backward passing, passing state in backward direction of the ring
|
||||
const int nlocal = static_cast<int>(rptr.size() - 1);
|
||||
utils::Assert(nlocal <= n + 1, "invalid local replica");
|
||||
std::vector<int> msg_back(n + 1);
|
||||
msg_back[0] = nlocal;
|
||||
// backward passing one hop the request
|
||||
ReturnType succ = RingPassing(msg_forward,
|
||||
ReturnType succ;
|
||||
succ = RingPassing(BeginPtr(msg_back),
|
||||
1 * sizeof(int), (n+1) * sizeof(int),
|
||||
0 * sizeof(int), n * sizeof(int),
|
||||
ring_next, ring_prev);
|
||||
if (succ != kSuccess) return succ;
|
||||
int msg_forward[2];
|
||||
msg_forward[0] = nlocal;
|
||||
succ = RingPassing(msg_forward,
|
||||
1 * sizeof(int), 2 * sizeof(int),
|
||||
0 * sizeof(int), 1 * sizeof(int),
|
||||
ring_prev, ring_next);
|
||||
if (succ != kSuccess) return succ;
|
||||
// check how much current node can help with the request
|
||||
// if (nlocal > ) {
|
||||
|
||||
//}
|
||||
|
||||
// calculate the number of things we can read from next link
|
||||
int nread_end = nlocal;
|
||||
for (int i = 1; i <= n; ++i) {
|
||||
nread_end = std::max(nread_end, msg_back[i] - i);
|
||||
}
|
||||
// gives the size of forward
|
||||
int nwrite_start = std::min(msg_forward[1] + 1, nread_end);
|
||||
// get the size of each segments
|
||||
std::vector<size_t> sizes(nread_end);
|
||||
for (int i = 0; i < nlocal; ++i) {
|
||||
sizes[i] = rptr[i + 1] - rptr[i];
|
||||
}
|
||||
// pass size through the link
|
||||
succ = RingPassing(BeginPtr(sizes),
|
||||
nlocal * sizeof(size_t),
|
||||
nread_end * sizeof(size_t),
|
||||
nwrite_start * sizeof(size_t),
|
||||
nread_end * sizeof(size_t),
|
||||
ring_next, ring_prev);
|
||||
if (succ != kSuccess) return succ;
|
||||
// update rptr
|
||||
rptr.resize(nread_end + 1);
|
||||
for (int i = nlocal; i < nread_end; ++i) {
|
||||
rptr[i + 1] = rptr[i] + sizes[i];
|
||||
}
|
||||
chkpt.resize(rptr.back());
|
||||
// pass data through the link
|
||||
succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end],
|
||||
rptr[nwrite_start], rptr[nread_end],
|
||||
ring_next, ring_prev);
|
||||
if (succ != kSuccess) {
|
||||
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
|
||||
}
|
||||
}
|
||||
{// forward passing, passing state in forward direction of the ring
|
||||
const int nlocal = static_cast<int>(rptr.size() - 1);
|
||||
utils::Assert(nlocal <= n + 1, "invalid local replica");
|
||||
std::vector<int> msg_forward(n + 1);
|
||||
msg_forward[0] = nlocal;
|
||||
// backward passing one hop the request
|
||||
ReturnType succ;
|
||||
succ = RingPassing(BeginPtr(msg_forward),
|
||||
1 * sizeof(int), (n+1) * sizeof(int),
|
||||
0 * sizeof(int), n * sizeof(int),
|
||||
ring_prev, ring_next);
|
||||
if (succ != kSuccess) return succ;
|
||||
int msg_back[2];
|
||||
msg_back[0] = nlocal;
|
||||
succ = RingPassing(msg_back,
|
||||
1 * sizeof(int), 2 * sizeof(int),
|
||||
0 * sizeof(int), 1 * sizeof(int),
|
||||
ring_next, ring_prev);
|
||||
if (succ != kSuccess) return succ;
|
||||
// calculate the number of things we can read from next link
|
||||
int nread_end = nlocal, nwrite_end = 1;
|
||||
// have to have itself in order to get other data from prev link
|
||||
if (nlocal != 0) {
|
||||
for (int i = 1; i <= n; ++i) {
|
||||
if (msg_forward[i] == 0) break;
|
||||
nread_end = std::max(nread_end, i + 1);
|
||||
nwrite_end = i + 1;
|
||||
}
|
||||
if (nwrite_end > n) nwrite_end = n;
|
||||
} else {
|
||||
nread_end = 0; nwrite_end = 0;
|
||||
}
|
||||
// gives the size of forward
|
||||
int nwrite_start = std::min(msg_back[1] - 1, nwrite_end);
|
||||
// next node miss the state of itself, cannot recover
|
||||
if (nwrite_start < 0) nwrite_start = nwrite_end = 0;
|
||||
// get the size of each segments
|
||||
std::vector<size_t> sizes(nread_end);
|
||||
for (int i = 0; i < nlocal; ++i) {
|
||||
sizes[i] = rptr[i + 1] - rptr[i];
|
||||
}
|
||||
// pass size through the link, check consistency
|
||||
succ = RingPassing(BeginPtr(sizes),
|
||||
nlocal * sizeof(size_t),
|
||||
nread_end * sizeof(size_t),
|
||||
nwrite_start * sizeof(size_t),
|
||||
nwrite_end * sizeof(size_t),
|
||||
ring_prev, ring_next);
|
||||
if (succ != kSuccess) return succ;
|
||||
// update rptr
|
||||
rptr.resize(nread_end + 1);
|
||||
for (int i = nlocal; i < nread_end; ++i) {
|
||||
rptr[i + 1] = rptr[i] + sizes[i];
|
||||
}
|
||||
chkpt.resize(rptr.back());
|
||||
// pass data through the link
|
||||
succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end],
|
||||
rptr[nwrite_start], rptr[nwrite_end],
|
||||
ring_prev, ring_next);
|
||||
if (succ != kSuccess) {
|
||||
rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ;
|
||||
}
|
||||
}
|
||||
return kSuccess;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user