diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index f2a104a90..6fb4a11cb 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -252,8 +252,9 @@ void AllreduceBase::ReConnectLinks(const char *cmd) { tree_links.plinks.clear(); for (size_t i = 0; i < all_links.size(); ++i) { utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket"); - // set the socket to non-blocking mode - all_links[i].sock.SetNonBlock(true); + // set the socket to non-blocking mode, enable TCP keepalive + all_links[i].sock.SetNonBlock(true); + all_links[i].sock.SetKeepAlive(true); if (tree_neighbors.count(all_links[i].rank) != 0) { if (all_links[i].rank == parent_rank) { parent_index = static_cast(tree_links.plinks.size()); diff --git a/src/allreduce_mock.h b/src/allreduce_mock.h index 36d760b70..f46ee6885 100644 --- a/src/allreduce_mock.h +++ b/src/allreduce_mock.h @@ -40,17 +40,17 @@ class AllreduceMock : public AllreduceRobust { ReduceFunction reducer, PreprocFunction prepare_fun, void *prepare_arg) { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial)); + this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce"); AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes, count, reducer, prepare_fun, prepare_arg); } virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial)); + this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast"); AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root); } virtual void CheckPoint(const ISerializable *global_model, const ISerializable *local_model) { - this->Verify(MockKey(rank, version_number, seq_counter, num_trial)); + this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint"); AllreduceRobust::CheckPoint(global_model, local_model); } @@ -82,10 +82,10 @@ class AllreduceMock : public AllreduceRobust { // record all mock actions std::map mock_map; // used to generate all kinds of exceptions - inline void Verify(const MockKey &key) { + inline void Verify(const MockKey &key, const char *name) { if (mock_map.count(key) != 0) { num_trial += 1; - utils::Error("[%d]@@@Hit Mock Error", rank); + utils::Error("[%d]@@@Hit Mock Error:%s", rank, name); } } }; diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index fe8013cb6..fb53a0777 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -431,7 +431,7 @@ ShortestDist(const std::pair &node_value, if (dist_in[i].first + 1 < res) { res = dist_in[i].first + 1; size = dist_in[i].second; - } + } } // add one hop @@ -575,7 +575,7 @@ AllreduceRobust::TryRecoverData(RecoverType role, } if (req_in[i] && links[i].size_write != size) { if (role == kHaveData || - (role == kPassData && links[recv_link].size_read != links[i].size_write)) { + (links[recv_link].size_read != links[i].size_write)) { selecter.WatchWrite(links[i].sock); } finished = false; @@ -728,10 +728,17 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re } int recv_link; std::vector req_in; - ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in); + // size of data + size_t data_size = size; + ReturnType succ = TryDecideRouting(role, &data_size, &recv_link, &req_in); if (succ != kSuccess) return succ; - utils::Check(size != 0, "zero size check point is not allowed"); - return TryRecoverData(role, sendrecvbuf, size, recv_link, req_in); + utils::Check(data_size != 0, "zero size check point is not allowed"); + if (role == kRequestData || role == kHaveData) { + utils::Check(data_size == size, + "Allreduce Recovered data size do not match the specification of function call\n"\ + "Please check if calling sequence of recovered program is the same the original one in current VersionNumber"); + } + return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in); } /*! * \brief try to run recover execution for a request action described by flag and seqno, diff --git a/src/socket.h b/src/socket.h index 899ab03a7..29d62db35 100644 --- a/src/socket.h +++ b/src/socket.h @@ -219,6 +219,16 @@ class TCPSocket : public Socket{ } explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) { } + /*! + * \brief enable/disable TCP keepalive + * \param keepalive whether to set the keep alive option on + */ + inline void SetKeepAlive(bool keepalive) { + int opt = static_cast(keepalive); + if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &opt, sizeof(opt)) < 0) { + Socket::Error("SetKeepAlive"); + } + } /*! * \brief create the socket, call this before using socket * \param af domain diff --git a/toolkit/kmeans.cpp b/toolkit/kmeans.cpp index 3a55a0427..11f191625 100644 --- a/toolkit/kmeans.cpp +++ b/toolkit/kmeans.cpp @@ -151,4 +151,3 @@ int main(int argc, char *argv[]) { rabit::Finalize(); return 0; } -