add keepalive to socket, fix recover problem when a node is requester and pass data

This commit is contained in:
tqchen 2014-12-21 17:55:08 -08:00
parent cfea4dbe85
commit a624051b85
5 changed files with 30 additions and 13 deletions

View File

@ -252,8 +252,9 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
tree_links.plinks.clear();
for (size_t i = 0; i < all_links.size(); ++i) {
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode
all_links[i].sock.SetNonBlock(true);
// set the socket to non-blocking mode, enable TCP keepalive
all_links[i].sock.SetNonBlock(true);
all_links[i].sock.SetKeepAlive(true);
if (tree_neighbors.count(all_links[i].rank) != 0) {
if (all_links[i].rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size());

View File

@ -40,17 +40,17 @@ class AllreduceMock : public AllreduceRobust {
ReduceFunction reducer,
PreprocFunction prepare_fun,
void *prepare_arg) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial));
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
count, reducer, prepare_fun, prepare_arg);
}
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial));
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
}
virtual void CheckPoint(const ISerializable *global_model,
const ISerializable *local_model) {
this->Verify(MockKey(rank, version_number, seq_counter, num_trial));
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
AllreduceRobust::CheckPoint(global_model, local_model);
}
@ -82,10 +82,10 @@ class AllreduceMock : public AllreduceRobust {
// record all mock actions
std::map<MockKey, int> mock_map;
// used to generate all kinds of exceptions
inline void Verify(const MockKey &key) {
inline void Verify(const MockKey &key, const char *name) {
if (mock_map.count(key) != 0) {
num_trial += 1;
utils::Error("[%d]@@@Hit Mock Error", rank);
utils::Error("[%d]@@@Hit Mock Error:%s", rank, name);
}
}
};

View File

@ -431,7 +431,7 @@ ShortestDist(const std::pair<bool, size_t> &node_value,
if (dist_in[i].first + 1 < res) {
res = dist_in[i].first + 1;
size = dist_in[i].second;
}
}
}
// add one hop
@ -575,7 +575,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
}
if (req_in[i] && links[i].size_write != size) {
if (role == kHaveData ||
(role == kPassData && links[recv_link].size_read != links[i].size_write)) {
(links[recv_link].size_read != links[i].size_write)) {
selecter.WatchWrite(links[i].sock);
}
finished = false;
@ -728,10 +728,17 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
}
int recv_link;
std::vector<bool> req_in;
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
// size of data
size_t data_size = size;
ReturnType succ = TryDecideRouting(role, &data_size, &recv_link, &req_in);
if (succ != kSuccess) return succ;
utils::Check(size != 0, "zero size check point is not allowed");
return TryRecoverData(role, sendrecvbuf, size, recv_link, req_in);
utils::Check(data_size != 0, "zero size check point is not allowed");
if (role == kRequestData || role == kHaveData) {
utils::Check(data_size == size,
"Allreduce Recovered data size do not match the specification of function call\n"\
"Please check if calling sequence of recovered program is the same the original one in current VersionNumber");
}
return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in);
}
/*!
* \brief try to run recover execution for a request action described by flag and seqno,

View File

@ -219,6 +219,16 @@ class TCPSocket : public Socket{
}
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
}
/*!
* \brief enable/disable TCP keepalive
* \param keepalive whether to set the keep alive option on
*/
inline void SetKeepAlive(bool keepalive) {
int opt = static_cast<int>(keepalive);
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &opt, sizeof(opt)) < 0) {
Socket::Error("SetKeepAlive");
}
}
/*!
* \brief create the socket, call this before using socket
* \param af domain

View File

@ -151,4 +151,3 @@ int main(int argc, char *argv[]) {
rabit::Finalize();
return 0;
}