finish reset link log
This commit is contained in:
parent
98756c068a
commit
42505f473d
@ -35,8 +35,10 @@ class AllReduceManager : public IEngine {
|
|||||||
// constant one-byte out-of-band message to indicate that an error has happened
|
// constant one-byte out-of-band message to indicate that an error has happened
|
||||||
// and mark for channel cleanup
|
// and mark for channel cleanup
|
||||||
const static char kOOBReset = 95;
|
const static char kOOBReset = 95;
|
||||||
|
// and mark for channel cleanup, after OOB signal
|
||||||
|
const static char kResetMark = 97;
|
||||||
// and mark for channel cleanup
|
// and mark for channel cleanup
|
||||||
const static char kOOBResetAck = 97;
|
const static char kResetAck = 97;
|
||||||
|
|
||||||
AllReduceManager(void) {
|
AllReduceManager(void) {
|
||||||
master_uri = "NULL";
|
master_uri = "NULL";
|
||||||
@ -173,7 +175,6 @@ class AllReduceManager : public IEngine {
|
|||||||
size_t count,
|
size_t count,
|
||||||
ReduceFunction reducer) {
|
ReduceFunction reducer) {
|
||||||
while (true) {
|
while (true) {
|
||||||
if (rank == rand() % 3) TryResetLinks();
|
|
||||||
ReturnType ret = TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer);
|
ReturnType ret = TryAllReduce(sendrecvbuf_, type_nbytes, count, reducer);
|
||||||
if (ret == kSuccess) return;
|
if (ret == kSuccess) return;
|
||||||
if (ret == kSockError) {
|
if (ret == kSockError) {
|
||||||
@ -280,116 +281,92 @@ class AllReduceManager : public IEngine {
|
|||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
|
links[i].InitBuffer(sizeof(int), 1 << 10, reduce_buffer_size);
|
||||||
links[i].ResetSize();
|
links[i].ResetSize();
|
||||||
|
links[i].except = false;
|
||||||
}
|
}
|
||||||
printf("[%d] start to reset link\n", rank);
|
// read and discard data from all channels until pass mark
|
||||||
while (true) {
|
while (true) {
|
||||||
printf("[%d] loop\n", rank);
|
|
||||||
bool finished = true;
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (links[i].sock.BadSocket()) continue;
|
if (links[i].sock.BadSocket()) continue;
|
||||||
if (links[i].size_write == 0) {
|
if (links[i].size_write == 0) {
|
||||||
char sig = kOOBReset;
|
char sig = kOOBReset;
|
||||||
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
|
ssize_t len = links[i].sock.Send(&sig, sizeof(sig), MSG_OOB);
|
||||||
// error will be filtered in next loop
|
// error will be filtered in next loop
|
||||||
if (len != -1) {
|
if (len == sizeof(sig)) links[i].size_write = 1;
|
||||||
links[i].size_write += len;
|
}
|
||||||
printf("[%d] send OOB success\n", rank);
|
if (links[i].size_write == 1) {
|
||||||
|
char sig = kResetMark;
|
||||||
|
ssize_t len = links[i].sock.Send(&sig, sizeof(sig));
|
||||||
|
if (len == sizeof(sig)) links[i].size_write = 2;
|
||||||
|
}
|
||||||
|
if (links[i].size_read == 0) {
|
||||||
|
int atmark = links[i].sock.AtMark();
|
||||||
|
if (atmark < 0) {
|
||||||
|
utils::Assert(links[i].sock.BadSocket(), "must already gone bad");
|
||||||
|
} else if (atmark > 0) {
|
||||||
|
links[i].size_read = 1;
|
||||||
|
} else {
|
||||||
|
// no at mark, read and discard data
|
||||||
|
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
|
||||||
|
// zero length, remote closed the connection, close socket
|
||||||
|
if (len == 0) links[i].sock.Close();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// need to send OOB to every other link
|
|
||||||
if (links[i].size_write == 0) finished = false;
|
|
||||||
}
|
}
|
||||||
if (finished) break;
|
|
||||||
}
|
|
||||||
printf("[%d] finish send all OOB\n", rank);
|
|
||||||
// wait for incoming except from all links
|
|
||||||
for (int i = 0; i < nlink; ++ i) {
|
|
||||||
if (links[i].sock.BadSocket()) continue;
|
|
||||||
printf("[%d] wait except\n", rank);
|
|
||||||
if (utils::SelectHelper::WaitExcept(links[i].sock) == -1) {
|
|
||||||
utils::Socket::Error("select");
|
|
||||||
}
|
|
||||||
printf("[%d] finish wait except\n", rank);
|
|
||||||
}
|
|
||||||
printf("[%d] start to discard link\n", rank);
|
|
||||||
// read and discard data from all channels until pass mark
|
|
||||||
while (true) {
|
|
||||||
utils::SelectHelper rsel;
|
utils::SelectHelper rsel;
|
||||||
bool finished = true;
|
bool finished = true;
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (links[i].sock.BadSocket()) continue;
|
if (links[i].size_write != 2 && !links[i].sock.BadSocket()) {
|
||||||
if (links[i].size_read == 0) {
|
rsel.WatchWrite(links[i].sock); finished = false;
|
||||||
int atmark = links[i].sock.AtMark();
|
|
||||||
if (atmark < 0) return kSockError;
|
|
||||||
if (atmark == 1) {
|
|
||||||
char oob_msg;
|
|
||||||
ssize_t len = links[i].sock.Recv(&oob_msg, sizeof(oob_msg), MSG_OOB);
|
|
||||||
if (len == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
|
|
||||||
finished = false; continue;
|
|
||||||
}
|
|
||||||
utils::Assert(oob_msg == kOOBReset, "wrong oob msg");
|
|
||||||
links[i].size_read = 1;
|
|
||||||
} else {
|
|
||||||
finished = false;
|
|
||||||
rsel.WatchRead(links[i].sock);
|
|
||||||
}
|
}
|
||||||
|
if (links[i].size_read == 0 && !links[i].sock.BadSocket()) {
|
||||||
|
rsel.WatchRead(links[i].sock); finished = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (finished) break;
|
if (finished) break;
|
||||||
// wait to read from the channels to discard data
|
// wait to read from the channels to discard data
|
||||||
rsel.Select();
|
rsel.Select();
|
||||||
printf("[%d] select finish read from\n", rank);
|
}
|
||||||
|
// start synchronization, use blocking I/O to avoid select
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (links[i].sock.BadSocket()) continue;
|
if (!links[i].sock.BadSocket()) {
|
||||||
if (rsel.CheckRead(links[i].sock)) {
|
char oob_mark;
|
||||||
ssize_t len = links[i].sock.Recv(links[i].buffer_head, links[i].buffer_size);
|
links[i].sock.SetNonBlock(false);
|
||||||
// zero length, remote closed the connection, close socket
|
ssize_t len = links[i].sock.Recv(&oob_mark, sizeof(oob_mark), MSG_WAITALL);
|
||||||
if (len == 0) {
|
if (len == 0) {
|
||||||
links[i].sock.Close();
|
links[i].sock.Close(); continue;
|
||||||
} else if (len == -1) {
|
} else if (len > 0) {
|
||||||
// when error happens here, oob_clear will remember
|
utils::Assert(oob_mark == kResetMark, "wrong oob msg");
|
||||||
if (errno == EAGAIN && errno == EWOULDBLOCK) printf("would block\n");
|
utils::Assert(!links[i].sock.AtMark(), "should already read past mark");
|
||||||
} else {
|
} else {
|
||||||
printf("[%d] discard %ld bytes\n", rank, len);
|
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
printf("[%d] discard all success\n", rank);
|
|
||||||
// start synchronization step
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
|
||||||
links[i].ResetSize();
|
|
||||||
}
|
}
|
||||||
|
// send out ack
|
||||||
|
char ack = kResetAck;
|
||||||
while (true) {
|
while (true) {
|
||||||
// selecter for TryResetLinks
|
len = links[i].sock.Send(&ack, sizeof(ack));
|
||||||
utils::SelectHelper rsel;
|
if (len == sizeof(ack)) break;
|
||||||
for (int i = 0; i < nlink; ++i) {
|
if (len == -1) {
|
||||||
if (links[i].sock.BadSocket()) continue;
|
if (errno != EAGAIN && errno != EWOULDBLOCK) break;
|
||||||
if (links[i].size_read == 0) rsel.WatchRead(links[i].sock);
|
|
||||||
if (links[i].size_write == 0) rsel.WatchWrite(links[i].sock);
|
|
||||||
}
|
}
|
||||||
printf("[%d] before select\n", rank);
|
}
|
||||||
rsel.Select();
|
}
|
||||||
printf("[%d] after select\n", rank);
|
}
|
||||||
bool finished = true;
|
// wait all ack
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (links[i].sock.BadSocket()) continue;
|
if (!links[i].sock.BadSocket()) {
|
||||||
if (links[i].size_read == 0 && rsel.CheckRead(links[i].sock)) {
|
|
||||||
char ack;
|
char ack;
|
||||||
links[i].ReadToArray(&ack, sizeof(ack));
|
ssize_t len = links[i].sock.Recv(&ack, sizeof(ack), MSG_WAITALL);
|
||||||
if (links[i].size_read != 0) {
|
if (len == 0) {
|
||||||
utils::Assert(ack == kOOBResetAck, "expect ack message");
|
links[i].sock.Close(); continue;
|
||||||
|
} else if (len > 0) {
|
||||||
|
utils::Assert(ack == kResetAck, "wrong Ack MSG");
|
||||||
|
} else {
|
||||||
|
utils::Assert(errno != EAGAIN|| errno != EWOULDBLOCK, "BUG");
|
||||||
|
}
|
||||||
|
// set back to nonblock mode
|
||||||
|
links[i].sock.SetNonBlock(true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (links[i].size_write == 0 && rsel.CheckWrite(links[i].sock)) {
|
|
||||||
char ack = kOOBResetAck;
|
|
||||||
links[i].WriteFromArray(&ack, sizeof(ack));
|
|
||||||
}
|
|
||||||
if (links[i].size_read == 0 || links[i].size_write == 0) finished = false;
|
|
||||||
}
|
|
||||||
if (finished) break;
|
|
||||||
}
|
|
||||||
printf("[%d] after the read write data success\n", rank);
|
|
||||||
for (int i = 0; i < nlink; ++i) {
|
for (int i = 0; i < nlink; ++i) {
|
||||||
if (links[i].sock.BadSocket()) return kSockError;
|
if (links[i].sock.BadSocket()) return kSockError;
|
||||||
}
|
}
|
||||||
@ -541,8 +518,11 @@ class AllReduceManager : public IEngine {
|
|||||||
char *buffer_head;
|
char *buffer_head;
|
||||||
// buffer size, in bytes
|
// buffer size, in bytes
|
||||||
size_t buffer_size;
|
size_t buffer_size;
|
||||||
|
// exception
|
||||||
|
bool except;
|
||||||
// constructor
|
// constructor
|
||||||
LinkRecord(void) {}
|
LinkRecord(void) {}
|
||||||
|
|
||||||
// initialize buffer
|
// initialize buffer
|
||||||
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
|
inline void InitBuffer(size_t type_nbytes, size_t count, size_t reduce_buffer_size) {
|
||||||
size_t n = (type_nbytes * count + 7)/ 8;
|
size_t n = (type_nbytes * count + 7)/ 8;
|
||||||
@ -587,7 +567,7 @@ class AllReduceManager : public IEngine {
|
|||||||
* \return true if it is a successful read, false if some error happens, check errno
|
* \return true if it is a successful read, false if some error happens, check errno
|
||||||
*/
|
*/
|
||||||
inline bool ReadToArray(void *recvbuf_, size_t max_size) {
|
inline bool ReadToArray(void *recvbuf_, size_t max_size) {
|
||||||
if (max_size == size_read ) return true;
|
if (max_size == size_read) return true;
|
||||||
char *p = static_cast<char*>(recvbuf_);
|
char *p = static_cast<char*>(recvbuf_);
|
||||||
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
|
ssize_t len = sock.Recv(p + size_read, max_size - size_read);
|
||||||
// length equals 0, remote disconnected
|
// length equals 0, remote disconnected
|
||||||
|
|||||||
@ -76,9 +76,9 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
printf("[%d] start at %s\n", rank, name.c_str());
|
printf("[%d] start at %s\n", rank, name.c_str());
|
||||||
TestMax(mock, n);
|
TestMax(mock, n);
|
||||||
printf("[%d] TestMax pass\n", rank);
|
printf("[%d] !!!TestMax pass\n", rank);
|
||||||
TestSum(mock, n);
|
TestSum(mock, n);
|
||||||
printf("[%d] TestSum pass\n", rank);
|
printf("[%d] !!!TestSum pass\n", rank);
|
||||||
sync::Finalize();
|
sync::Finalize();
|
||||||
printf("[%d] all check pass\n", rank);
|
printf("[%d] all check pass\n", rank);
|
||||||
return 0;
|
return 0;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user