fix one bug, another comes
parent 993ff8bb91
commit 46b5d46111

@@ -111,7 +111,11 @@ void AllReduceBase::Shutdown(void) {
   links.clear();
   utils::TCPSocket::Finalize();
 }
-// set the parameters for AllReduce
+/*!
+ * \brief set parameters to the engine
+ * \param name parameter name
+ * \param val parameter value
+ */
 void AllReduceBase::SetParam(const char *name, const char *val) {
   if (!strcmp(name, "master_uri")) master_uri = val;
   if (!strcmp(name, "master_port")) master_port = atoi(val);

@@ -39,7 +39,11 @@ class AllReduceBase : public IEngine {
   void Shutdown(void);
   // initialize the manager
   void Init(void);
-  /*! \brief set parameters to the sync manager */
+  /*!
+   * \brief set parameters to the engine
+   * \param name parameter name
+   * \param val parameter value
+   */
   virtual void SetParam(const char *name, const char *val);
   /*! \brief get rank */
   virtual int GetRank(void) const {

@@ -19,6 +19,13 @@ AllReduceRobust::AllReduceRobust(void) {
   result_buffer_round = 2;
   seq_counter = 0;
 }
+void AllReduceRobust::SetParam(const char *name, const char *val) {
+  AllReduceBase::SetParam(name, val);
+  if (!strcmp(name, "result_buffer_round")) result_buffer_round = atoi(val);
+  if (!strcmp(name, "result_replicate")) {
+    result_buffer_round = std::max(world_size / atoi(val), 1);
+  }
+}
 /*!
  * \brief perform in-place allreduce, on sendrecvbuf
  * this function is NOT thread-safe
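
Note on the new result_replicate parameter: as far as the surrounding code suggests, result_buffer_round becomes max(world_size / replicas, 1), and AllReduce keeps the result with sequence number s only on ranks where rank % result_buffer_round == s % result_buffer_round, so roughly the requested number of nodes hold a copy of each result. One caveat: the division assumes atoi(val) is positive; a zero value would divide by zero. A minimal standalone sketch of the arithmetic (function name and values are illustrative, not part of the commit):

    #include <algorithm>
    #include <cstdio>

    // mirrors result_buffer_round = std::max(world_size / atoi(val), 1)
    int BufferRound(int world_size, int result_replicate) {
      return std::max(world_size / result_replicate, 1);
    }

    int main(void) {
      // e.g. 8 nodes, 2 replicas requested -> round of 4, so the ranks
      // with rank % 4 == s % 4 (2 of the 8 nodes) retain result s
      std::printf("round=%d\n", BufferRound(8, 2));  // prints round=4
      return 0;
    }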

@@ -32,7 +39,7 @@ void AllReduceRobust::AllReduce(void *sendrecvbuf_,
                                 size_t count,
                                 ReduceFunction reducer) {
   bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
-  utils::LogPrintf("[%d] AllReduce recovered=%d\n", rank, recovered);
+  //utils::LogPrintf("[%d] AllReduce recovered=%d\n", rank, recovered);
   // now we are free to remove the last result, if any
   if (resbuf.LastSeqNo() != -1 &&
       (resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {

@@ -302,12 +309,15 @@ ShortestDist(const std::pair<bool, size_t> &node_value,
   int res = std::numeric_limits<int>::max();
   for (size_t i = 0; i < dist_in.size(); ++i) {
     if (i == out_index) continue;
-    if (dist_in[i].first < res) {
-      res = dist_in[i].first; size = dist_in[i].second;
+    if (dist_in[i].first == std::numeric_limits<int>::max()) continue;
+    if (dist_in[i].first + 1 < res) {
+      res = dist_in[i].first + 1;
+      size = dist_in[i].second;
     }
   }
-  // add one hop
-  return std::make_pair(res + 1, size);
+  return std::make_pair(res, size);
 }
 /*!
  * \brief message passing function, used to decide the
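
The ShortestDist change fixes an overflow hazard rather than just restyling the loop: res starts at INT_MAX, and the old code added the hop only after the loop, so a node whose neighbors all reported INT_MAX (unreachable) returned INT_MAX + 1, which in practice wraps negative and makes an unreachable route look shortest. The fix skips the INT_MAX sentinel and folds the +1 into the comparison, so res either carries a real hop count or stays at INT_MAX. A minimal sketch of the fixed pattern (simplified signature, hypothetical inputs):

    #include <cstdio>
    #include <limits>
    #include <utility>
    #include <vector>

    // Distance (in hops) to the nearest data holder; INT_MAX means
    // "unreachable" and must never enter the +1 arithmetic.
    std::pair<int, size_t>
    Shortest(const std::vector<std::pair<int, size_t> > &dist_in) {
      int res = std::numeric_limits<int>::max();
      size_t size = 0;
      for (size_t i = 0; i < dist_in.size(); ++i) {
        if (dist_in[i].first == std::numeric_limits<int>::max()) continue;
        if (dist_in[i].first + 1 < res) {  // add the hop before comparing
          res = dist_in[i].first + 1;
          size = dist_in[i].second;
        }
      }
      return std::make_pair(res, size);  // stays INT_MAX if nothing reachable
    }

    int main(void) {
      std::vector<std::pair<int, size_t> > in;
      in.push_back(std::make_pair(std::numeric_limits<int>::max(), 0));
      in.push_back(std::make_pair(2, 64));  // 2 hops away, 64 bytes
      std::pair<int, size_t> d = Shortest(in);
      std::printf("dist=%d size=%lu\n", d.first, (unsigned long)d.second);
      return 0;  // prints dist=3 size=64
    }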

@@ -365,7 +375,8 @@ AllReduceRobust::TryDecideRouting(AllReduceRobust::RecoverType role,
   for (size_t i = 0; i < dist_in.size(); ++i) {
     if (dist_in[i].first != std::numeric_limits<int>::max()) {
       utils::Check(best_link == -2 || *p_size == dist_in[i].second,
-                   "AllReduce size inconsistent, size=%lu, reporting=%lu", *p_size, dist_in[i].second);
+                   "[%d] AllReduce size inconsistent, distin=%lu, size=%lu, reporting=%lu\n",
+                   rank, dist_in[i].first, *p_size, dist_in[i].second);
       if (best_link == -2 || dist_in[i].first < dist_in[best_link].first) {
         best_link = static_cast<int>(i);
         *p_size = dist_in[i].second;
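
One observation on the expanded Check message (not part of the commit): dist_in[i].first is an int but is printed with %lu, so the specifier and the argument disagree; %d would match. The added rank and distance fields do make a size mismatch attributable to a specific node and link, which fits the debugging theme of this commit.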

@@ -416,7 +427,6 @@ AllReduceRobust::TryRecoverData(RecoverType role,
                                 size_t size,
                                 int recv_link,
                                 const std::vector<bool> &req_in) {
-  utils::LogPrintf("[%d] recv_link=%d\n", rank, recv_link);
   // no need to run recovery for zero size message
   if (links.size() == 0 || size == 0) return kSuccess;
   utils::Assert(req_in.size() == links.size(), "TryRecoverData");

@@ -432,6 +442,7 @@ AllReduceRobust::TryRecoverData(RecoverType role,
     // do not need to provide data or receive data, directly exit
     if (!req_data) return kSuccess;
   }
+  utils::LogPrintf("[%d] !!Need to pass data\n", rank);
   utils::Assert(recv_link >= 0 || role == kHaveData, "recv_link must be active");
   for (int i = 0; i < nlink; ++i) {
     links[i].ResetSize();

@@ -548,6 +559,7 @@ AllReduceRobust::TryGetResult(void *sendrecvbuf, size_t size, int seqno, bool re
   } else {
     role = kRequestData;
   }
+  utils::LogPrintf("[%d] role=%d\n", rank, role);
   int recv_link;
   std::vector<bool> req_in;
   ReturnType succ = TryDecideRouting(role, &size, &recv_link, &req_in);

@@ -20,6 +20,12 @@ class AllReduceRobust : public AllReduceBase {
  public:
   AllReduceRobust(void);
   virtual ~AllReduceRobust(void) {}
+  /*!
+   * \brief set parameters to the engine
+   * \param name parameter name
+   * \param val parameter value
+   */
+  virtual void SetParam(const char *name, const char *val);
   /*!
    * \brief perform in-place allreduce, on sendrecvbuf
    * this function is NOT thread-safe

@@ -71,11 +77,11 @@ class AllReduceRobust : public AllReduceBase {
   /*! \brief type of roles each node can play during recovery */
   enum RecoverType {
     /*! \brief current node have data */
-    kHaveData,
+    kHaveData = 0,
     /*! \brief current node request data */
-    kRequestData,
+    kRequestData = 1,
     /*! \brief current node only helps to pass data around */
-    kPassData
+    kPassData = 2
   };
   /*!
    * \brief summary of actions proposed in all nodes
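
Giving RecoverType explicit values reads as wire-format hygiene: the role is what the new role=%d log prints, and recovery messages presumably carry it between nodes as a plain integer, where implicitly numbered enumerators can silently shift if an entry is inserted later. A small illustration of that concern (the message struct is hypothetical, not from this codebase):

    #include <cassert>

    enum RecoverType {  // explicit values keep the encoding stable
      kHaveData = 0,
      kRequestData = 1,
      kPassData = 2
    };

    // Hypothetical on-wire header: the role travels as a plain int, so
    // every node must agree on the numbering.
    struct RecoverMsg {
      int role;
    };

    int main(void) {
      RecoverMsg msg;
      msg.role = static_cast<int>(kRequestData);
      // the receiver decodes by value, not by name
      assert(static_cast<RecoverType>(msg.role) == kRequestData);
      return 0;
    }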

@@ -1,8 +1,8 @@
 #!/bin/bash
-if [ "$#" -ne 4 ];
+if [ "$#" -lt 4 ];
 then
     echo "Usage <nslave> <ndata> <config> <round_files_dir>"
     exit -1
 fi
 
-../submit_job_tcp.py $1 test_recover $2 $3 $4 $5
+../submit_job_tcp.py $1 test_recover "${@:2}"
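
The script now requires at least four arguments instead of exactly four, and "${@:2}" expands to every positional parameter from $2 onward (quoted, so arguments containing spaces survive intact). The old form $2 $3 $4 $5 always passed an empty $5 under the previous exactly-four check and would have silently dropped anything past the fifth; with the new form, a call carrying five arguments, say 4 500 cfg rounds extra_flag, forwards 500 cfg rounds extra_flag to submit_job_tcp.py unchanged.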