basic recovery works

This commit is contained in:
tqchen 2014-12-03 12:19:08 -08:00
parent 8a6768763d
commit 2523288509
6 changed files with 22 additions and 21 deletions

View File

@ -88,7 +88,7 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
* \brief connect to the master to fix the missing links
* this function is also used when the engine start up
*/
void AllreduceBase::ReConnectLinks(void) {
void AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (master_uri == "NULL") {
rank = 0; return;
@ -105,7 +105,7 @@ void AllreduceBase::ReConnectLinks(void) {
utils::Check(magic == kMagic, "sync::Invalid master message, init failure");
utils::Assert(master.SendAll(&rank, sizeof(rank)) == sizeof(rank), "ReConnectLink failure 3");
master.SendStr(job_id);
master.SendStr(std::string("start"));
master.SendStr(std::string(cmd));
{// get new ranks
int newrank;
utils::Assert(master.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),

View File

@ -227,8 +227,9 @@ class AllreduceBase : public IEngine {
/*!
* \brief connect to the master to fix the missing links
* this function is also used when the engine start up
* \param cmd the command to send to the master
*/
void ReConnectLinks(void);
void ReConnectLinks(const char *cmd = "start");
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*

View File

@ -281,14 +281,6 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
}
return kSuccess;
}
/*!
* \brief try to reconnect the broken links
* \return this function can return kSuccess or kSockError
*/
AllreduceRobust::ReturnType AllreduceRobust::TryReConnectLinks(void) {
// placeholder: no reconnection is attempted here, the missing implementation
// is reported through utils::Error
utils::Error("TryReConnectLinks: not implemented");
// only reached if utils::Error hands control back to the caller
return kSuccess;
}
/*!
* \brief if err_type indicates an error
* recover links according to the error type reported
@ -298,12 +290,20 @@ AllreduceRobust::ReturnType AllreduceRobust::TryReConnectLinks(void) {
*/
bool AllreduceRobust::CheckAndRecover(ReturnType err_type) {
if (err_type == kSuccess) return true;
// simple way, shutdown all links
for (size_t i = 0; i < links.size(); ++i) {
if (!links[i].sock.BadSocket()) links[i].sock.Close();
}
ReConnectLinks("recover");
return false;
// this was old way
while(err_type != kSuccess) {
switch(err_type) {
case kGetExcept: err_type = TryResetLinks(); break;
case kSockError: {
TryResetLinks();
err_type = TryReConnectLinks();
ReConnectLinks();
err_type = kSuccess;
break;
}
default: utils::Assert(false, "RecoverLinks: cannot reach here");

View File

@ -70,7 +70,7 @@ class AllreduceRobust : public AllreduceBase {
* this function is only used for test purpose
*/
virtual void InitAfterException(void) {
//this->CheckAndRecover(kGetExcept);
// run the link recovery routine with the exception error type (kGetExcept);
// the boolean result of CheckAndRecover is intentionally not inspected here
this->CheckAndRecover(kGetExcept);
}
private:
@ -234,12 +234,7 @@ class AllreduceRobust : public AllreduceBase {
* when kSockError is returned, it simply means there are bad sockets in the links,
* and some link recovery procedure is needed
*/
ReturnType TryResetLinks(void);
/*!
* \brief try to reconnect the broken links
* \return this function can return kSuccess or kSockError
*/
ReturnType TryReConnectLinks(void);
ReturnType TryResetLinks(void);
/*!
* \brief if err_type indicates an error
* recover links according to the error type reported

View File

@ -1,7 +1,7 @@
/*!
* \file rabit-inl.h
* \brief implementation of inline template function for rabit interface
*
*
* \author Tianqi Chen
*/
#ifndef RABIT_RABIT_INL_H

View File

@ -10,6 +10,7 @@ import os
import socket
import struct
import subprocess
import random
from threading import Thread
"""
@ -136,6 +137,7 @@ class Master:
wait_conn = {}
# set of nodes that are still pending to get up
todo_nodes = range(nslave)
random.shuffle(todo_nodes)
# maps job id to rank
job_map = {}
# list of workers that are pending to be assigned a rank
@ -149,7 +151,10 @@ class Master:
assert s.rank not in wait_conn
shutdown[s.rank] = s
continue
assert s.cmd == 'start'
assert s.cmd == 'start' or s.cmd == 'recover'
if s.cmd == 'recover':
assert s.rank >= 0
print 'Recieve recover signal from %d' % s.rank
rank = s.decide_rank(job_map)
if rank == -1:
assert len(todo_nodes) != 0