seems ok version

This commit is contained in:
tqchen 2014-12-01 20:18:25 -08:00
parent 46b5d46111
commit b76cd5858c
2 changed files with 10 additions and 13 deletions

View File

@ -39,7 +39,6 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_,
size_t count, size_t count,
ReduceFunction reducer) { ReduceFunction reducer) {
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter); bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
//utils::LogPrintf("[%d] AllReduce recovered=%d\n", rank, recovered);
// now we are free to remove the last result, if any // now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 && if (resbuf.LastSeqNo() != -1 &&
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { (resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
@ -442,14 +441,10 @@ AllReduceRobust::TryRecoverData(RecoverType role,
// do not need to provide data or receive data, directly exit // do not need to provide data or receive data, directly exit
if (!req_data) return kSuccess; if (!req_data) return kSuccess;
} }
utils::LogPrintf("[%d] !!Need to pass data\n", rank);
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active"); utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
links[i].ResetSize(); links[i].ResetSize();
} }
if (role == kPassData) {
links[recv_link].InitBuffer(1, size, reduce_buffer_size);
}
while (true) { while (true) {
bool finished = true; bool finished = true;
utils::SelectHelper selecter; utils::SelectHelper selecter;
@ -457,9 +452,12 @@ AllReduceRobust::TryRecoverData(RecoverType role,
if (i == recv_link && links[i].size_read != size) { if (i == recv_link && links[i].size_read != size) {
selecter.WatchRead(links[i].sock); selecter.WatchRead(links[i].sock);
finished = false; finished = false;
} }
if (req_in[i] && links[i].size_write != size) { if (req_in[i] && links[i].size_write != size) {
selecter.WatchWrite(links[i].sock); if (role == kHaveData ||
(role == kPassData && links[recv_link].size_read != links[i].size_write)) {
selecter.WatchWrite(links[i].sock);
}
finished = false; finished = false;
} }
selecter.WatchException(links[i].sock); selecter.WatchException(links[i].sock);
@ -496,12 +494,12 @@ AllReduceRobust::TryRecoverData(RecoverType role,
utils::Assert(min_write <= links[pid].size_read, "boundary check"); utils::Assert(min_write <= links[pid].size_read, "boundary check");
if (!links[pid].ReadToRingBuffer(min_write)) return kSockError; if (!links[pid].ReadToRingBuffer(min_write)) return kSockError;
} }
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
if (req_in[i] && selecter.CheckWrite(links[i].sock)) { if (req_in[i] && selecter.CheckWrite(links[i].sock) && links[pid].size_read != links[i].size_write) {
size_t start = links[i].size_write % buffer_size; size_t start = links[i].size_write % buffer_size;
// send out data from ring buffer // send out data from ring buffer
size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write); size_t nwrite = std::min(buffer_size - start, links[pid].size_read - links[i].size_write);
ssize_t len = links[pid].sock.Send(links[pid].buffer_head + start, nwrite); ssize_t len = links[i].sock.Send(links[pid].buffer_head + start, nwrite);
if (len != -1) { if (len != -1) {
links[i].size_write += len; links[i].size_write += len;
} else { } else {
@ -559,7 +557,6 @@ AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
} else { } else {
role = kRequestData; role = kRequestData;
} }
utils::LogPrintf("[%d] role=%d\n", rank, role);
int recv_link; int recv_link;
std::vector<bool> req_in; std::vector<bool> req_in;
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in); ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);
@ -590,7 +587,6 @@ bool AllReduceRobust::RecoverExec(void *buf, size_t size, int flag, int seqno) {
} }
// request // request
ActionSummary req(flag, seqno); ActionSummary req(flag, seqno);
utils::LogPrintf("[%d] propose flag=%d, seq=%d\n", rank, flag, seqno);
while (true) { while (true) {
// action // action
ActionSummary act = req; ActionSummary act = req;

View File

@ -19,6 +19,7 @@ inline void TestMax(test::Mock &mock, size_t n, int ntrial) {
ndata[i] = (i * (rank+1)) % 111; ndata[i] = (i * (rank+1)) % 111;
} }
mock.AllReduce<op::Max>(&ndata[0], ndata.size()); mock.AllReduce<op::Max>(&ndata[0], ndata.size());
if (ntrial == 0 && rank == 15) throw MockException();
for (size_t i = 0; i < ndata.size(); ++i) { for (size_t i = 0; i < ndata.size(); ++i) {
float rmax = (i * 1) % 111; float rmax = (i * 1) % 111;
for (int r = 0; r < nproc; ++r) { for (int r = 0; r < nproc; ++r) {