diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index f90e5d621..4c30b62d2 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -88,7 +88,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) { * \brief connect to the master to fix the the missing links * this function is also used when the engine start up */ -void AllreduceBase::ReConnectLinks(void) { +void AllreduceBase::ReConnectLinks(const char *cmd) { // single node mode if (master_uri == "NULL") { rank = 0; return; @@ -105,7 +105,7 @@ void AllreduceBase::ReConnectLinks(void) { utils::Check(magic == kMagic, "sync::Invalid master message, init failure"); utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3"); master.SendStr(job_id); - master.SendStr(std::string("start")); + master.SendStr(std::string(cmd)); {// get new ranks int newrank; utils::Assert(master.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank), diff --git a/src/allreduce_base.h b/src/allreduce_base.h index 436916cda..cd9a5b0d0 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -227,8 +227,9 @@ class AllreduceBase : public IEngine { /*! * \brief connect to the master to fix the the missing links * this function is also used when the engine start up + * \param cmd possible command to sent to master */ - void ReConnectLinks(void); + void ReConnectLinks(const char *cmd = "start"); /*! * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure * diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index d2339a3be..6aba63e82 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -281,14 +281,6 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) { } return kSuccess; } -/*! - * \brief try to reconnect the broken links - * \return this function can kSuccess or kSockError - */ -AllreduceRobust::ReturnType AllreduceRobust::TryReConnectLinks(void) { - utils::Error("TryReConnectLinks: not implemented"); - return kSuccess; -} /*! * \brief if err_type indicates an error * recover links according to the error type reported @@ -298,12 +290,20 @@ AllreduceRobust::ReturnType AllreduceRobust::TryReConnectLinks(void) { */ bool AllreduceRobust::CheckAndRecover(ReturnType err_type) { if (err_type == kSuccess) return true; + // simple way, shutdown all links + for (size_t i = 0; i < links.size(); ++i) { + if (!links[i].sock.BadSocket()) links[i].sock.Close(); + } + ReConnectLinks("recover"); + return false; + // this was old way while(err_type != kSuccess) { switch(err_type) { case kGetExcept: err_type = TryResetLinks(); break; case kSockError: { TryResetLinks(); - err_type = TryReConnectLinks(); + ReConnectLinks(); + err_type = kSuccess; break; } default: utils::Assert(false, "RecoverLinks: cannot reach here"); diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index 26e45f16c..ad660da94 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -70,7 +70,7 @@ class AllreduceRobust : public AllreduceBase { * this function is only used for test purpose */ virtual void InitAfterException(void) { - //this->CheckAndRecover(kGetExcept); + this->CheckAndRecover(kGetExcept); } private: @@ -234,12 +234,7 @@ class AllreduceRobust : public AllreduceBase { * when kSockError is returned, it simply means there are bad sockets in the links, * and some link recovery proceduer is needed */ - ReturnType TryResetLinks(void); - /*! - * \brief try to reconnect the broken links - * \return this function can kSuccess or kSockError - */ - ReturnType TryReConnectLinks(void); + ReturnType TryResetLinks(void); /*! * \brief if err_type indicates an error * recover links according to the error type reported diff --git a/src/rabit-inl.h b/src/rabit-inl.h index 631686582..f3fd39b2a 100644 --- a/src/rabit-inl.h +++ b/src/rabit-inl.h @@ -1,7 +1,7 @@ /*! * \file rabit-inl.h * \brief implementation of inline template function for rabit interface - * + * * \author Tianqi Chen */ #ifndef RABIT_RABIT_INL_H diff --git a/src/rabit_master.py b/src/rabit_master.py index 1cfc00dc0..cfa1cce9a 100644 --- a/src/rabit_master.py +++ b/src/rabit_master.py @@ -10,6 +10,7 @@ import os import socket import struct import subprocess +import random from threading import Thread """ @@ -136,6 +137,7 @@ class Master: wait_conn = {} # set of nodes that is pending for getting up todo_nodes = range(nslave) + random.shuffle(todo_nodes) # maps job id to rank job_map = {} # list of workers that is pending to be assigned rank @@ -149,7 +151,10 @@ class Master: assert s.rank not in wait_conn shutdown[s.rank] = s continue - assert s.cmd == 'start' + assert s.cmd == 'start' or s.cmd == 'recover' + if s.cmd == 'recover': + assert s.rank >= 0 + print 'Recieve recover signal from %d' % s.rank rank = s.decide_rank(job_map) if rank == -1: assert len(todo_nodes) != 0