add keepalive to socket, fix recover problem when a node is requester and pass data
This commit is contained in:
parent
cfea4dbe85
commit
a624051b85
@ -252,8 +252,9 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
tree_links.plinks.clear();
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode
|
||||
all_links[i].sock.SetNonBlock(true);
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
all_links[i].sock.SetNonBlock(true);
|
||||
all_links[i].sock.SetKeepAlive(true);
|
||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||
if (all_links[i].rank == parent_rank) {
|
||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||
|
||||
@ -40,17 +40,17 @@ class AllreduceMock : public AllreduceRobust {
|
||||
ReduceFunction reducer,
|
||||
PreprocFunction prepare_fun,
|
||||
void *prepare_arg) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial));
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
||||
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
||||
count, reducer, prepare_fun, prepare_arg);
|
||||
}
|
||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial));
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
|
||||
}
|
||||
virtual void CheckPoint(const ISerializable *global_model,
|
||||
const ISerializable *local_model) {
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial));
|
||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
|
||||
AllreduceRobust::CheckPoint(global_model, local_model);
|
||||
}
|
||||
|
||||
@ -82,10 +82,10 @@ class AllreduceMock : public AllreduceRobust {
|
||||
// record all mock actions
|
||||
std::map<MockKey, int> mock_map;
|
||||
// used to generate all kinds of exceptions
|
||||
inline void Verify(const MockKey &key) {
|
||||
inline void Verify(const MockKey &key, const char *name) {
|
||||
if (mock_map.count(key) != 0) {
|
||||
num_trial += 1;
|
||||
utils::Error("[%d]@@@Hit Mock Error", rank);
|
||||
utils::Error("[%d]@@@Hit Mock Error:%s", rank, name);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -431,7 +431,7 @@ ShortestDist(const std::pair<bool, size_t> &node_value,
|
||||
if (dist_in[i].first + 1 < res) {
|
||||
res = dist_in[i].first + 1;
|
||||
size = dist_in[i].second;
|
||||
}
|
||||
}
|
||||
}
|
||||
// add one hop
|
||||
|
||||
@ -575,7 +575,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
||||
}
|
||||
if (req_in[i] && links[i].size_write != size) {
|
||||
if (role == kHaveData ||
|
||||
(role == kPassData && links[recv_link].size_read != links[i].size_write)) {
|
||||
(links[recv_link].size_read != links[i].size_write)) {
|
||||
selecter.WatchWrite(links[i].sock);
|
||||
}
|
||||
finished = false;
|
||||
@ -728,10 +728,17 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
|
||||
}
|
||||
int recv_link;
|
||||
std::vector<bool> req_in;
|
||||
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
|
||||
// size of data
|
||||
size_t data_size = size;
|
||||
ReturnType succ = TryDecideRouting(role, &data_size, &recv_link, &req_in);
|
||||
if (succ != kSuccess) return succ;
|
||||
utils::Check(size != 0, "zero size check point is not allowed");
|
||||
return TryRecoverData(role, sendrecvbuf, size, recv_link, req_in);
|
||||
utils::Check(data_size != 0, "zero size check point is not allowed");
|
||||
if (role == kRequestData || role == kHaveData) {
|
||||
utils::Check(data_size == size,
|
||||
"Allreduce Recovered data size do not match the specification of function call\n"\
|
||||
"Please check if calling sequence of recovered program is the same the original one in current VersionNumber");
|
||||
}
|
||||
return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in);
|
||||
}
|
||||
/*!
|
||||
* \brief try to run recover execution for a request action described by flag and seqno,
|
||||
|
||||
10
src/socket.h
10
src/socket.h
@ -219,6 +219,16 @@ class TCPSocket : public Socket{
|
||||
}
|
||||
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
|
||||
}
|
||||
/*!
|
||||
* \brief enable/disable TCP keepalive
|
||||
* \param keepalive whether to set the keep alive option on
|
||||
*/
|
||||
inline void SetKeepAlive(bool keepalive) {
|
||||
int opt = static_cast<int>(keepalive);
|
||||
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &opt, sizeof(opt)) < 0) {
|
||||
Socket::Error("SetKeepAlive");
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief create the socket, call this before using socket
|
||||
* \param af domain
|
||||
|
||||
@ -151,4 +151,3 @@ int main(int argc, char *argv[]) {
|
||||
rabit::Finalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user