diff --git a/src/engine_base.cc b/src/engine_base.cc index 4e9a65229..3b08d1502 100644 --- a/src/engine_base.cc +++ b/src/engine_base.cc @@ -111,7 +111,11 @@ void AllReduceBase::Shutdown(void) { links.clear(); utils::TCPSocket::Finalize(); } -// set the parameters for AllReduce +/*! + * \brief set parameters to the engine + * \param name parameter name + * \param val parameter value + */ void AllReduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "master_uri")) master_uri = val; if (!strcmp(name, "master_port")) master_port = atoi(val); diff --git a/src/engine_base.h b/src/engine_base.h index 0cc281cff..9e533fe27 100644 --- a/src/engine_base.h +++ b/src/engine_base.h @@ -39,7 +39,11 @@ class AllReduceBase : public IEngine { void Shutdown(void); // initialize the manager void Init(void); - /*! \brief set parameters to the sync manager */ + /*! + * \brief set parameters to the engine + * \param name parameter name + * \param val parameter value + */ virtual void SetParam(const char *name, const char *val); /*! \brief get rank */ virtual int GetRank(void) const { diff --git a/src/engine_robust.cc b/src/engine_robust.cc index 099e83934..9db11bc96 100644 --- a/src/engine_robust.cc +++ b/src/engine_robust.cc @@ -19,6 +19,13 @@ AllReduceRobust::AllReduceRobust(void) { result_buffer_round = 2; seq_counter = 0; } +void AllReduceRobust::SetParam(const char *name, const char *val) { + AllReduceBase::SetParam(name, val); + if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val); + if (!strcmp(name, "result_replicate")) { + result_buffer_round = std::max(world_size / atoi(val), 1); + } +} /*! * \brief perform in-place allreduce, on sendrecvbuf * this function is NOT thread-safe @@ -32,7 +39,7 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_, size_t count, ReduceFunction reducer) { bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter); - utils::LogPrintf("[%d] AllReduce recovered=%d\n", rank, recovered); + //utils::LogPrintf("[%d] AllReduce recovered=%d\n", rank, recovered); // now we are free to remove the last result, if any if (resbuf.LastSeqNo() != -1 && (resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { @@ -302,12 +309,15 @@ ShortestDist(const std::pair &node_value, int res = std::numeric_limits::max(); for (size_t i = 0; i < dist_in.size(); ++i) { if (i == out_index) continue; - if (dist_in[i].first < res) { - res = dist_in[i].first; size = dist_in[i].second; + if (dist_in[i].first == std::numeric_limits::max()) continue; + if (dist_in[i].first + 1 < res) { + res = dist_in[i].first + 1; + size = dist_in[i].second; } } // add one hop - return std::make_pair(res + 1, size); + + return std::make_pair(res, size); } /*! * \brief message passing function, used to decide the @@ -365,7 +375,8 @@ AllReduceRobust::TryDecideRouting(AllReduceRobust::RecoverType role, for (size_t i = 0; i < dist_in.size(); ++i) { if (dist_in[i].first != std::numeric_limits::max()) { utils::Check(best_link == -2 || *p_size == dist_in[i].second, - "AllReduce size inconsistent, size=%lu, reporting=%lu", *p_size, dist_in[i].second); + "[%d] AllReduce size inconsistent, distin=%lu, size=%lu, reporting=%lu\n", + rank, dist_in[i].first, *p_size, dist_in[i].second); if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) { best_link = static_cast(i); *p_size = dist_in[i].second; @@ -416,7 +427,6 @@ AllReduceRobust::TryRecoverData(RecoverType role, size_t size, int recv_link, const std::vector &req_in) { - utils::LogPrintf("[%d] recv_link=%d\n", rank, recv_link); // no need to run recovery for zero size message if (links.size() == 0 || size == 0) return kSuccess; utils::Assert(req_in.size() == links.size(), "TryRecoverData"); @@ -432,6 +442,7 @@ AllReduceRobust::TryRecoverData(RecoverType role, // do not need to provide data or receive data, directly exit if (!req_data) return kSuccess; } + utils::LogPrintf("[%d] !!Need to pass data\n", rank); utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active"); for (int i = 0; i < nlink; ++i) { links[i].ResetSize(); @@ -548,6 +559,7 @@ AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re } else { role = kRequestData; } + utils::LogPrintf("[%d] role=%d\n", rank, role); int recv_link; std::vector req_in; ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in); diff --git a/src/engine_robust.h b/src/engine_robust.h index 783b2deb7..7116764d8 100644 --- a/src/engine_robust.h +++ b/src/engine_robust.h @@ -20,6 +20,12 @@ class AllReduceRobust : public AllReduceBase { public: AllReduceRobust(void); virtual ~AllReduceRobust(void) {} + /*! + * \brief set parameters to the engine + * \param name parameter name + * \param val parameter value + */ + virtual void SetParam(const char *name, const char *val); /*! * \brief perform in-place allreduce, on sendrecvbuf * this function is NOT thread-safe @@ -71,11 +77,11 @@ class AllReduceRobust : public AllReduceBase { /*! \brief type of roles each node can play during recovery */ enum RecoverType { /*! \brief current node have data */ - kHaveData, + kHaveData = 0, /*! \brief current node request data */ - kRequestData, + kRequestData = 1, /*! \brief current node only helps to pass data around */ - kPassData + kPassData = 2 }; /*! * \brief summary of actions proposed in all nodes diff --git a/test/test.sh b/test/test.sh index 9b27abb8b..78d267157 100755 --- a/test/test.sh +++ b/test/test.sh @@ -1,8 +1,8 @@ #!/bin/bash -if [ "$#" -ne 4 ]; +if [ "$#" -lt 4 ]; then echo "Usage " exit -1 fi -../submit_job_tcp.py $1 test_recover $2 $3 $4 $5 +../submit_job_tcp.py $1 test_recover "${@:2}"