fix one bug, another comes

This commit is contained in:
tqchen 2014-12-01 19:53:41 -08:00
parent 993ff8bb91
commit 46b5d46111
5 changed files with 39 additions and 13 deletions

View File

@ -111,7 +111,11 @@ void AllReduceBase::Shutdown(void) {
links.clear(); links.clear();
utils::TCPSocket::Finalize(); utils::TCPSocket::Finalize();
} }
// set the parameters for AllReduce /*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
void AllReduceBase::SetParam(const char *name, const char *val) { void AllReduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "master_uri")) master_uri = val; if (!strcmp(name, "master_uri")) master_uri = val;
if (!strcmp(name, "master_port")) master_port = atoi(val); if (!strcmp(name, "master_port")) master_port = atoi(val);

View File

@ -39,7 +39,11 @@ class AllReduceBase : public IEngine {
void Shutdown(void); void Shutdown(void);
// initialize the manager // initialize the manager
void Init(void); void Init(void);
/*! \brief set parameters to the sync manager */ /*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val); virtual void SetParam(const char *name, const char *val);
/*! \brief get rank */ /*! \brief get rank */
virtual int GetRank(void) const { virtual int GetRank(void) const {

View File

@ -19,6 +19,13 @@ AllReduceRobust::AllReduceRobust(void) {
result_buffer_round = 2; result_buffer_round = 2;
seq_counter = 0; seq_counter = 0;
} }
/*!
 * \brief set parameters to the engine
 *  every parameter is first forwarded to AllReduceBase::SetParam,
 *  then the parameters specific to the robust engine are handled:
 *   - result_buffer_round: value stored in result_buffer_round (at least 1)
 *   - result_replicate: result_buffer_round is derived as
 *     max(world_size / result_replicate, 1)
 * \param name parameter name
 * \param val parameter value
 */
void AllReduceRobust::SetParam(const char *name, const char *val) {
  AllReduceBase::SetParam(name, val);
  if (!strcmp(name, "result_buffer_round")) {
    // result_buffer_round is used as a modulus by this engine, so it must
    // stay positive; atoi yields 0 when val is not a number
    result_buffer_round = std::max(atoi(val), 1);
  }
  if (!strcmp(name, "result_replicate")) {
    const int nreplica = atoi(val);
    // a non-positive replica count is ignored and the current setting is kept:
    // dividing by zero (val is "0" or not a number) is undefined behavior
    // NOTE(review): world_size is read at the time of this call; confirm it
    // is already initialized when parameters are set
    if (nreplica > 0) {
      result_buffer_round = std::max(world_size / nreplica, 1);
    }
  }
}
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf * \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
@ -32,7 +39,7 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_,
size_t count, size_t count,
ReduceFunction reducer) { ReduceFunction reducer) {
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter); bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
utils::LogPrintf("[%d] AllReduce recovered=%d\n", rank, recovered); //utils::LogPrintf("[%d] AllReduce recovered=%d\n", rank, recovered);
// now we are free to remove the last result, if any // now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 && if (resbuf.LastSeqNo() != -1 &&
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) { (resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
@ -302,12 +309,15 @@ ShortestDist(const std::pair<bool, size_t> &node_value,
int res = std::numeric_limits<int>::max(); int res = std::numeric_limits<int>::max();
for (size_t i = 0; i < dist_in.size(); ++i) { for (size_t i = 0; i < dist_in.size(); ++i) {
if (i == out_index) continue; if (i == out_index) continue;
if (dist_in[i].first < res) { if (dist_in[i].first == std::numeric_limits<int>::max()) continue;
res = dist_in[i].first; size = dist_in[i].second; if (dist_in[i].first + 1 < res) {
res = dist_in[i].first + 1;
size = dist_in[i].second;
} }
} }
// add one hop // add one hop
return std::make_pair(res + 1, size);
return std::make_pair(res, size);
} }
/*! /*!
* \brief message passing function, used to decide the * \brief message passing function, used to decide the
@ -365,7 +375,8 @@ AllReduceRobust::TryDecideRouting(AllReduceRobust::RecoverType role,
for (size_t i = 0; i < dist_in.size(); ++i) { for (size_t i = 0; i < dist_in.size(); ++i) {
if (dist_in[i].first != std::numeric_limits<int>::max()) { if (dist_in[i].first != std::numeric_limits<int>::max()) {
utils::Check(best_link == -2 || *p_size == dist_in[i].second, utils::Check(best_link == -2 || *p_size == dist_in[i].second,
"AllReduce size inconsistent, size=%lu, reporting=%lu", *p_size, dist_in[i].second); "[%d] AllReduce size inconsistent, distin=%lu, size=%lu, reporting=%lu\n",
rank, dist_in[i].first, *p_size, dist_in[i].second);
if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) { if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) {
best_link = static_cast<int>(i); best_link = static_cast<int>(i);
*p_size = dist_in[i].second; *p_size = dist_in[i].second;
@ -416,7 +427,6 @@ AllReduceRobust::TryRecoverData(RecoverType role,
size_t size, size_t size,
int recv_link, int recv_link,
const std::vector<bool> &req_in) { const std::vector<bool> &req_in) {
utils::LogPrintf("[%d] recv_link=%d\n", rank, recv_link);
// no need to run recovery for zero size message // no need to run recovery for zero size message
if (links.size() == 0 || size == 0) return kSuccess; if (links.size() == 0 || size == 0) return kSuccess;
utils::Assert(req_in.size() == links.size(), "TryRecoverData"); utils::Assert(req_in.size() == links.size(), "TryRecoverData");
@ -432,6 +442,7 @@ AllReduceRobust::TryRecoverData(RecoverType role,
// do not need to provide data or receive data, directly exit // do not need to provide data or receive data, directly exit
if (!req_data) return kSuccess; if (!req_data) return kSuccess;
} }
utils::LogPrintf("[%d] !!Need to pass data\n", rank);
utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active"); utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
for (int i = 0; i < nlink; ++i) { for (int i = 0; i < nlink; ++i) {
links[i].ResetSize(); links[i].ResetSize();
@ -548,6 +559,7 @@ AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
} else { } else {
role = kRequestData; role = kRequestData;
} }
utils::LogPrintf("[%d] role=%d\n", rank, role);
int recv_link; int recv_link;
std::vector<bool> req_in; std::vector<bool> req_in;
ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in); ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);

View File

@ -20,6 +20,12 @@ class AllReduceRobust : public AllReduceBase {
public: public:
AllReduceRobust(void); AllReduceRobust(void);
virtual ~AllReduceRobust(void) {} virtual ~AllReduceRobust(void) {}
/*!
* \brief set parameters to the engine
* \param name parameter name
* \param val parameter value
*/
virtual void SetParam(const char *name, const char *val);
/*! /*!
* \brief perform in-place allreduce, on sendrecvbuf * \brief perform in-place allreduce, on sendrecvbuf
* this function is NOT thread-safe * this function is NOT thread-safe
@ -71,11 +77,11 @@ class AllReduceRobust : public AllReduceBase {
/*! \brief type of roles each node can play during recovery */ /*! \brief type of roles each node can play during recovery */
enum RecoverType { enum RecoverType {
/*! \brief current node have data */ /*! \brief current node have data */
kHaveData, kHaveData = 0,
/*! \brief current node request data */ /*! \brief current node request data */
kRequestData, kRequestData = 1,
/*! \brief current node only helps to pass data around */ /*! \brief current node only helps to pass data around */
kPassData kPassData = 2
}; };
/*! /*!
* \brief summary of actions proposed in all nodes * \brief summary of actions proposed in all nodes

View File

@ -1,8 +1,8 @@
#!/bin/bash #!/bin/bash
if [ "$#" -ne 4 ]; if [ "$#" -lt 4 ];
then then
echo "Usage <nslave> <ndata> <config> <round_files_dir>" echo "Usage <nslave> <ndata> <config> <round_files_dir>"
exit -1 exit -1
fi fi
../submit_job_tcp.py $1 test_recover $2 $3 $4 $5 ../submit_job_tcp.py $1 test_recover "${@:2}"