From 328cf187bad60f8f04e1e872b79215beb89e8828 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 6 Dec 2014 23:00:10 -0800 Subject: [PATCH] check in the ring passing --- src/allreduce_robust.cc | 128 +++++++++++++++++++++++++++++++++++----- 1 file changed, 113 insertions(+), 15 deletions(-) diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index eace31bb6..dbb318c33 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -740,23 +740,121 @@ AllreduceRobust::TryRecoverLocalState(std::vector *p_local_rptr, utils::Assert(chkpt.length() == 0, "local chkpt space inconsistent"); } const int n = num_local_replica; - // message send to previous link - { - int msg_forward[2]; - int nlocal = static_cast(rptr.size() - 1); - msg_forward[0] = nlocal; - utils::Assert(msg_forward[0] <= n, "invalid local replica"); + {// backward passing, passing state in backward direction of the ring + const int nlocal = static_cast(rptr.size() - 1); + utils::Assert(nlocal <= n + 1, "invalid local replica"); + std::vector msg_back(n + 1); + msg_back[0] = nlocal; // backward passing one hop the request - ReturnType succ = RingPassing(msg_forward, - 1 * sizeof(int), 2 * sizeof(int), - 0 * sizeof(int), 1 * sizeof(int), - ring_prev, ring_next); + ReturnType succ; + succ = RingPassing(BeginPtr(msg_back), + 1 * sizeof(int), (n+1) * sizeof(int), + 0 * sizeof(int), n * sizeof(int), + ring_next, ring_prev); if (succ != kSuccess) return succ; - // check how much current node can help with the request - // if (nlocal > ) { - - //} - + int msg_forward[2]; + msg_forward[0] = nlocal; + succ = RingPassing(msg_forward, + 1 * sizeof(int), 2 * sizeof(int), + 0 * sizeof(int), 1 * sizeof(int), + ring_prev, ring_next); + if (succ != kSuccess) return succ; + // calculate the number of things we can read from next link + int nread_end = nlocal; + for (int i = 1; i <= n; ++i) { + nread_end = std::max(nread_end, msg_back[i] - i); + } + // gives the size of forward + int nwrite_start = std::min(msg_forward[1] + 1, nread_end); + // get the size of each segments + std::vector sizes(nread_end); + for (int i = 0; i < nlocal; ++i) { + sizes[i] = rptr[i + 1] - rptr[i]; + } + // pass size through the link + succ = RingPassing(BeginPtr(sizes), + nlocal * sizeof(size_t), + nread_end * sizeof(size_t), + nwrite_start * sizeof(size_t), + nread_end * sizeof(size_t), + ring_next, ring_prev); + if (succ != kSuccess) return succ; + // update rptr + rptr.resize(nread_end + 1); + for (int i = nlocal; i < nread_end; ++i) { + rptr[i + 1] = rptr[i] + sizes[i]; + } + chkpt.resize(rptr.back()); + // pass data through the link + succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end], + rptr[nwrite_start], rptr[nread_end], + ring_next, ring_prev); + if (succ != kSuccess) { + rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ; + } + } + {// forward passing, passing state in forward direction of the ring + const int nlocal = static_cast(rptr.size() - 1); + utils::Assert(nlocal <= n + 1, "invalid local replica"); + std::vector msg_forward(n + 1); + msg_forward[0] = nlocal; + // backward passing one hop the request + ReturnType succ; + succ = RingPassing(BeginPtr(msg_forward), + 1 * sizeof(int), (n+1) * sizeof(int), + 0 * sizeof(int), n * sizeof(int), + ring_prev, ring_next); + if (succ != kSuccess) return succ; + int msg_back[2]; + msg_back[0] = nlocal; + succ = RingPassing(msg_back, + 1 * sizeof(int), 2 * sizeof(int), + 0 * sizeof(int), 1 * sizeof(int), + ring_next, ring_prev); + if (succ != kSuccess) return succ; + // calculate the number of things we can read from next link + int nread_end = nlocal, nwrite_end = 1; + // have to have itself in order to get other data from prev link + if (nlocal != 0) { + for (int i = 1; i <= n; ++i) { + if (msg_forward[i] == 0) break; + nread_end = std::max(nread_end, i + 1); + nwrite_end = i + 1; + } + if (nwrite_end > n) nwrite_end = n; + } else { + nread_end = 0; nwrite_end = 0; + } + // gives the size of forward + int nwrite_start = std::min(msg_back[1] - 1, nwrite_end); + // next node miss the state of itself, cannot recover + if (nwrite_start < 0) nwrite_start = nwrite_end = 0; + // get the size of each segments + std::vector sizes(nread_end); + for (int i = 0; i < nlocal; ++i) { + sizes[i] = rptr[i + 1] - rptr[i]; + } + // pass size through the link, check consistency + succ = RingPassing(BeginPtr(sizes), + nlocal * sizeof(size_t), + nread_end * sizeof(size_t), + nwrite_start * sizeof(size_t), + nwrite_end * sizeof(size_t), + ring_prev, ring_next); + if (succ != kSuccess) return succ; + // update rptr + rptr.resize(nread_end + 1); + for (int i = nlocal; i < nread_end; ++i) { + rptr[i + 1] = rptr[i] + sizes[i]; + } + chkpt.resize(rptr.back()); + // pass data through the link + succ = RingPassing(&chkpt[0], rptr[nlocal], rptr[nread_end], + rptr[nwrite_start], rptr[nwrite_end], + ring_prev, ring_next); + if (succ != kSuccess) { + rptr.resize(nlocal + 1); chkpt.resize(rptr.back()); return succ; + } } return kSuccess; }