diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 841525bc2..0235723a6 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -448,7 +448,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_, // read data from childs for (int i = 0; i < nlink; ++i) { if (i != parent_index && selecter.CheckRead(links[i].sock)) { - ReturnType ret = links[i].ReadToRingBuffer(size_up_out); + ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size); if (ret != kSuccess) { return ReportError(&links[i], ret); } @@ -778,13 +778,13 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_, if (finished) break; selecter.Select(); if (read_ptr != stop_read && selecter.CheckRead(next.sock)) { - ReturnType ret = next.ReadToRingBuffer(reduce_ptr); + ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read); if (ret != kSuccess) { return ReportError(&next, ret); } // sync the rate read_ptr = next.size_read; - utils::Assert(read_ptr <= stop_read, "read_ptr boundary check"); + utils::Assert(read_ptr <= stop_read, "[%d] read_ptr boundary check", rank); const size_t buffer_size = next.buffer_size; size_t max_reduce = (read_ptr / type_nbytes) * type_nbytes; while (reduce_ptr < max_reduce) { diff --git a/src/allreduce_base.h b/src/allreduce_base.h index af4c7cfdc..a9eafea39 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -278,15 +278,19 @@ class AllreduceBase : public IEngine { * \brief read data into ring-buffer, with care not to existing useful override data * position after protect_start * \param protect_start all data start from protect_start is still needed in buffer - * read shall not override this + * read shall not override this + * \param max_size_read maximum logical amount we can read, size_read cannot exceed this value * \return the type of reading */ - inline ReturnType ReadToRingBuffer(size_t protect_start) { + inline ReturnType ReadToRingBuffer(size_t protect_start, size_t max_size_read) { utils::Assert(buffer_head != NULL, "ReadToRingBuffer: buffer not allocated"); + utils::Assert(size_read <= max_size_read, "ReadToRingBuffer: max_size_read check"); size_t ngap = size_read - protect_start; utils::Assert(ngap <= buffer_size, "Allreduce: boundary check"); size_t offset = size_read % buffer_size; - size_t nmax = std::min(buffer_size - ngap, buffer_size - offset); + size_t nmax = max_size_read - size_read; + nmax = std::min(nmax, buffer_size - ngap); + nmax = std::min(nmax, buffer_size - offset); if (nmax == 0) return kSuccess; ssize_t len = sock.Recv(buffer_head + offset, nmax); // length equals 0, remote disconnected diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index 040cc1134..84abaceba 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role, if (req_in[i]) min_write = std::min(links[i].size_write, min_write); } utils::Assert(min_write <= links[pid].size_read, "boundary check"); - ReturnType ret = links[pid].ReadToRingBuffer(min_write); + ReturnType ret = links[pid].ReadToRingBuffer(min_write, size); if (ret != kSuccess) { return ReportError(&links[pid], ret); }