Add TCP keepalive to sockets; fix a recovery problem that occurred when a node is a requester and also has to pass data on to other nodes
This commit is contained in:
parent
cfea4dbe85
commit
a624051b85
@ -252,8 +252,9 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
|||||||
tree_links.plinks.clear();
|
tree_links.plinks.clear();
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||||
// set the socket to non-blocking mode
|
// set the socket to non-blocking mode, enable TCP keepalive
|
||||||
all_links[i].sock.SetNonBlock(true);
|
all_links[i].sock.SetNonBlock(true);
|
||||||
|
all_links[i].sock.SetKeepAlive(true);
|
||||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||||
if (all_links[i].rank == parent_rank) {
|
if (all_links[i].rank == parent_rank) {
|
||||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||||
|
|||||||
@ -40,17 +40,17 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
ReduceFunction reducer,
|
ReduceFunction reducer,
|
||||||
PreprocFunction prepare_fun,
|
PreprocFunction prepare_fun,
|
||||||
void *prepare_arg) {
|
void *prepare_arg) {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial));
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "AllReduce");
|
||||||
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
AllreduceRobust::Allreduce(sendrecvbuf_, type_nbytes,
|
||||||
count, reducer, prepare_fun, prepare_arg);
|
count, reducer, prepare_fun, prepare_arg);
|
||||||
}
|
}
|
||||||
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
virtual void Broadcast(void *sendrecvbuf_, size_t total_size, int root) {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial));
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "Broadcast");
|
||||||
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
|
AllreduceRobust::Broadcast(sendrecvbuf_, total_size, root);
|
||||||
}
|
}
|
||||||
virtual void CheckPoint(const ISerializable *global_model,
|
virtual void CheckPoint(const ISerializable *global_model,
|
||||||
const ISerializable *local_model) {
|
const ISerializable *local_model) {
|
||||||
this->Verify(MockKey(rank, version_number, seq_counter, num_trial));
|
this->Verify(MockKey(rank, version_number, seq_counter, num_trial), "CheckPoint");
|
||||||
AllreduceRobust::CheckPoint(global_model, local_model);
|
AllreduceRobust::CheckPoint(global_model, local_model);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,10 +82,10 @@ class AllreduceMock : public AllreduceRobust {
|
|||||||
// record all mock actions
|
// record all mock actions
|
||||||
std::map<MockKey, int> mock_map;
|
std::map<MockKey, int> mock_map;
|
||||||
// used to generate all kinds of exceptions
|
// used to generate all kinds of exceptions
|
||||||
inline void Verify(const MockKey &key) {
|
inline void Verify(const MockKey &key, const char *name) {
|
||||||
if (mock_map.count(key) != 0) {
|
if (mock_map.count(key) != 0) {
|
||||||
num_trial += 1;
|
num_trial += 1;
|
||||||
utils::Error("[%d]@@@Hit Mock Error", rank);
|
utils::Error("[%d]@@@Hit Mock Error:%s", rank, name);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -431,7 +431,7 @@ ShortestDist(const std::pair<bool, size_t> &node_value,
|
|||||||
if (dist_in[i].first + 1 < res) {
|
if (dist_in[i].first + 1 < res) {
|
||||||
res = dist_in[i].first + 1;
|
res = dist_in[i].first + 1;
|
||||||
size = dist_in[i].second;
|
size = dist_in[i].second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// add one hop
|
// add one hop
|
||||||
|
|
||||||
@ -575,7 +575,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
|
|||||||
}
|
}
|
||||||
if (req_in[i] && links[i].size_write != size) {
|
if (req_in[i] && links[i].size_write != size) {
|
||||||
if (role == kHaveData ||
|
if (role == kHaveData ||
|
||||||
(role == kPassData && links[recv_link].size_read != links[i].size_write)) {
|
(links[recv_link].size_read != links[i].size_write)) {
|
||||||
selecter.WatchWrite(links[i].sock);
|
selecter.WatchWrite(links[i].sock);
|
||||||
}
|
}
|
||||||
finished = false;
|
finished = false;
|
||||||
@ -728,10 +728,17 @@ AllreduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
|
|||||||
}
|
}
|
||||||
int recv_link;
|
int recv_link;
|
||||||
std::vector<bool> req_in;
|
std::vector<bool> req_in;
|
||||||
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
|
// size of data
|
||||||
|
size_t data_size = size;
|
||||||
|
ReturnType succ = TryDecideRouting(role, &data_size, &recv_link, &req_in);
|
||||||
if (succ != kSuccess) return succ;
|
if (succ != kSuccess) return succ;
|
||||||
utils::Check(size != 0, "zero size check point is not allowed");
|
utils::Check(data_size != 0, "zero size check point is not allowed");
|
||||||
return TryRecoverData(role, sendrecvbuf, size, recv_link, req_in);
|
if (role == kRequestData || role == kHaveData) {
|
||||||
|
utils::Check(data_size == size,
|
||||||
|
"Allreduce Recovered data size do not match the specification of function call\n"\
|
||||||
|
"Please check if calling sequence of recovered program is the same the original one in current VersionNumber");
|
||||||
|
}
|
||||||
|
return TryRecoverData(role, sendrecvbuf, data_size, recv_link, req_in);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief try to run recover execution for a request action described by flag and seqno,
|
* \brief try to run recover execution for a request action described by flag and seqno,
|
||||||
|
|||||||
10
src/socket.h
10
src/socket.h
@ -219,6 +219,16 @@ class TCPSocket : public Socket{
|
|||||||
}
|
}
|
||||||
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
|
explicit TCPSocket(SOCKET sockfd) : Socket(sockfd) {
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief enable/disable TCP keepalive
|
||||||
|
* \param keepalive whether to set the keep alive option on
|
||||||
|
*/
|
||||||
|
inline void SetKeepAlive(bool keepalive) {
|
||||||
|
int opt = static_cast<int>(keepalive);
|
||||||
|
if (setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &opt, sizeof(opt)) < 0) {
|
||||||
|
Socket::Error("SetKeepAlive");
|
||||||
|
}
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief create the socket, call this before using socket
|
* \brief create the socket, call this before using socket
|
||||||
* \param af domain
|
* \param af domain
|
||||||
|
|||||||
@ -151,4 +151,3 @@ int main(int argc, char *argv[]) {
|
|||||||
rabit::Finalize();
|
rabit::Finalize();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user