diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 6ef941425..f90e5d621 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -55,7 +55,8 @@ void AllreduceBase::Shutdown(void) { utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); master.SendStr(job_id); - master.SendStr(std::string("shutdown")); + master.SendStr(std::string("shutdown")); + master.Close(); utils::TCPSocket::Finalize(); } /*! @@ -102,7 +103,6 @@ void AllreduceBase::ReConnectLinks(void) { utils::Assert(master.SendAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 1"); utils::Assert(master.RecvAll(&magic, sizeof(magic)) == sizeof(magic), "ReConnectLink failure 2"); utils::Check(magic == kMagic, "sync::Invalid master message, init failure"); - utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); master.SendStr(job_id); master.SendStr(std::string("start")); @@ -112,10 +112,11 @@ void AllreduceBase::ReConnectLinks(void) { "ReConnectLink failure 4"); utils::Assert(master.RecvAll(&parent_rank, sizeof(parent_rank)) == sizeof(parent_rank), "ReConnectLink failure 4"); + utils::Assert(master.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), + "ReConnectLink failure 4"); utils::Assert(rank == -1 || newrank == rank, "must keep rank to same if the node already have one"); rank = newrank; - } - + } // create listening socket utils::TCPSocket sock_listen; sock_listen.Create(); @@ -125,7 +126,6 @@ void AllreduceBase::ReConnectLinks(void) { // get number of to connect and number of to accept nodes from master int num_conn, num_accept, num_error = 1; - do { // send over good links std::vector good_link; @@ -146,7 +146,7 @@ void AllreduceBase::ReConnectLinks(void) { utils::Assert(master.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), "ReConnectLink failure 7"); utils::Assert(master.RecvAll(&num_accept, sizeof(num_accept)) == sizeof(num_accept), - "ReConnectLink failure 8"); + "ReConnectLink failure 8"); num_error = 0; for (int i = 0; i < num_conn; ++i) { LinkRecord r; @@ -202,7 +202,6 @@ void AllreduceBase::ReConnectLinks(void) { links[i].sock.SetNonBlock(true); if (links[i].rank == parent_rank) parent_index = static_cast(i); } - utils::LogPrintf("[%d] parent_rank=%d, parent_index=%d, nlink=%d\n", rank, parent_rank, parent_index, (int)links.size()); if (parent_rank != -1) { utils::Assert(parent_index != -1, "cannot find parent in the link"); } diff --git a/src/allreduce_base.h b/src/allreduce_base.h index d5172f9f7..436916cda 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -80,7 +80,7 @@ class AllreduceBase : public IEngine { */ virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { utils::Assert(TryBroadcast(sendrecvbuf_, total_size, root) == kSuccess, - "Allreduce failed"); + "Broadcast failed"); } /*! * \brief load latest check point diff --git a/src/rabit-inl.h b/src/rabit-inl.h index c38766bbb..631686582 100644 --- a/src/rabit-inl.h +++ b/src/rabit-inl.h @@ -103,7 +103,7 @@ inline void Broadcast(std::vector *sendrecv_data, int root) { sendrecv_data->resize(size); } if (size != 0) { - Broadcast(&sendrecv_data[0], size * sizeof(DType), root); + Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root); } } inline void Broadcast(std::string *sendrecv_data, int root) { @@ -113,7 +113,7 @@ inline void Broadcast(std::string *sendrecv_data, int root) { sendrecv_data->resize(size); } if (size != 0) { - Broadcast(&sendrecv_data[0], size * sizeof(char), root); + Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root); } } diff --git a/src/rabit_master.py b/src/rabit_master.py index 85b981972..1cfc00dc0 100644 --- a/src/rabit_master.py +++ b/src/rabit_master.py @@ -77,6 +77,8 @@ class SlaveEntry: self.sock.sendint(rank) # send parent rank self.sock.sendint((rank + 1) / 2 - 1) + # send world size + self.sock.sendint(nslave) while True: ngood = self.sock.recvint() goodset = set([]) @@ -88,8 +90,6 @@ class SlaveEntry: for r in badset: if r in wait_conn: conset.append(r) - print 'rank=%d' % rank - print 'conset=%s' % str(conset) self.sock.sendint(len(conset)) self.sock.sendint(len(badset) - len(conset)) for r in conset: @@ -109,7 +109,6 @@ class SlaveEntry: for r in rmset: wait_conn.pop(r, None) self.wait_accept = len(badset) - len(conset) - print 'wait=%d' % self.wait_accept return rmset class Master: diff --git a/test/test_allreduce.cpp b/test/test_allreduce.cpp index 625d9592a..707b1a22a 100644 --- a/test/test_allreduce.cpp +++ b/test/test_allreduce.cpp @@ -80,7 +80,7 @@ int main(int argc, char *argv[]) { TestSum(mock, n); utils::LogPrintf("[%d] !!!TestSum pass\n", rank); int step = std::max(nproc / 3, 1); - for (int i = 0; i < nproc; i += step) { + for (int i = 0; i < nproc; i += step) { TestBcast(mock, n, i); } utils::LogPrintf("[%d] !!!TestBcast pass\n", rank);