return values in Init and Finalize (#96)
* make init functions return values * address the comments
This commit is contained in:
parent
fc85f776f4
commit
65b718a5e7
@ -32,14 +32,16 @@ typedef unsigned long rbt_ulong; // NOLINT(*)
|
||||
* from environment variables.
|
||||
* \param argc number of arguments in argv
|
||||
* \param argv the array of input arguments
|
||||
* \return true if rabit is initialized successfully otherwise false
|
||||
*/
|
||||
RABIT_DLL void RabitInit(int argc, char *argv[]);
|
||||
RABIT_DLL bool RabitInit(int argc, char *argv[]);
|
||||
|
||||
/*!
|
||||
* \brief finalize the rabit engine,
|
||||
* call this function after you finished all jobs.
|
||||
* \return true if rabit is finalized successfully, otherwise false
|
||||
*/
|
||||
RABIT_DLL void RabitFinalize(void);
|
||||
RABIT_DLL bool RabitFinalize(void);
|
||||
|
||||
/*!
|
||||
* \brief get rank of current process
|
||||
|
||||
@ -161,9 +161,9 @@ class IEngine {
|
||||
};
|
||||
|
||||
/*! \brief initializes the engine module */
|
||||
void Init(int argc, char *argv[]);
|
||||
bool Init(int argc, char *argv[]);
|
||||
/*! \brief finalizes the engine module */
|
||||
void Finalize(void);
|
||||
bool Finalize(void);
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine(void);
|
||||
|
||||
|
||||
@ -103,12 +103,12 @@ inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &
|
||||
} // namespace op
|
||||
|
||||
// initialize the rabit engine
|
||||
inline void Init(int argc, char *argv[]) {
|
||||
engine::Init(argc, argv);
|
||||
inline bool Init(int argc, char *argv[]) {
|
||||
return engine::Init(argc, argv);
|
||||
}
|
||||
// finalize the rabit engine
|
||||
inline void Finalize(void) {
|
||||
engine::Finalize();
|
||||
inline bool Finalize(void) {
|
||||
return engine::Finalize();
|
||||
}
|
||||
// get the rank of current process
|
||||
inline int GetRank(void) {
|
||||
|
||||
@ -73,12 +73,14 @@ struct BitOR;
|
||||
* \brief initializes rabit, call this once at the beginning of your program
|
||||
* \param argc number of arguments in argv
|
||||
* \param argv the array of input arguments
|
||||
* \return true if initialized successfully, otherwise false
|
||||
*/
|
||||
inline void Init(int argc, char *argv[]);
|
||||
inline bool Init(int argc, char *argv[]);
|
||||
/*!
|
||||
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
|
||||
* \return true if finalized successfully, otherwise false
|
||||
*/
|
||||
inline void Finalize();
|
||||
inline bool Finalize();
|
||||
/*! \brief gets rank of the current process
|
||||
* \return rank number of worker*/
|
||||
inline int GetRank();
|
||||
|
||||
@ -57,7 +57,7 @@ AllreduceBase::AllreduceBase(void) {
|
||||
}
|
||||
|
||||
// initialization function
|
||||
void AllreduceBase::Init(int argc, char* argv[]) {
|
||||
bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
// setup from environment variables
|
||||
// handler to get variables from env
|
||||
for (size_t i = 0; i < env_vars.size(); ++i) {
|
||||
@ -122,24 +122,30 @@ void AllreduceBase::Init(int argc, char* argv[]) {
|
||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||
this->host_uri = utils::SockAddr::GetHostName();
|
||||
// get information from tracker
|
||||
this->ReConnectLinks();
|
||||
return this->ReConnectLinks();
|
||||
}
|
||||
|
||||
void AllreduceBase::Shutdown(void) {
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
all_links[i].sock.Close();
|
||||
}
|
||||
all_links.clear();
|
||||
tree_links.plinks.clear();
|
||||
bool AllreduceBase::Shutdown(void) {
|
||||
try {
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
all_links[i].sock.Close();
|
||||
}
|
||||
all_links.clear();
|
||||
tree_links.plinks.clear();
|
||||
|
||||
if (tracker_uri == "NULL") return;
|
||||
// notify tracker rank i have shutdown
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("shutdown"));
|
||||
tracker.Close();
|
||||
// close listening sockets
|
||||
sock_listen.Close();
|
||||
utils::TCPSocket::Finalize();
|
||||
if (tracker_uri == "NULL") return true;
|
||||
// notify tracker rank i have shutdown
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("shutdown"));
|
||||
tracker.Close();
|
||||
// close listening sockets
|
||||
sock_listen.Close();
|
||||
utils::TCPSocket::Finalize();
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "failed to shutdown due to %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
if (tracker_uri == "NULL") {
|
||||
@ -252,167 +258,179 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
||||
* \brief connect to the tracker to fix the missing links
|
||||
* this function is also used when the engine starts up
|
||||
*/
|
||||
void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// single node mode
|
||||
if (tracker_uri == "NULL") {
|
||||
rank = 0; world_size = 1; return;
|
||||
rank = 0; world_size = 1; return true;
|
||||
}
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string(cmd));
|
||||
try {
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string(cmd));
|
||||
|
||||
// the rank of previous link, next link in ring
|
||||
int prev_rank, next_rank;
|
||||
// the rank of neighbors
|
||||
std::map<int, int> tree_neighbors;
|
||||
using utils::Assert;
|
||||
// get new ranks
|
||||
int newrank, num_neighbors;
|
||||
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
||||
// the rank of previous link, next link in ring
|
||||
int prev_rank, next_rank;
|
||||
// the rank of neighbors
|
||||
std::map<int, int> tree_neighbors;
|
||||
using utils::Assert;
|
||||
// get new ranks
|
||||
int newrank, num_neighbors;
|
||||
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==\
|
||||
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == \
|
||||
sizeof(parent_rank), "ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(rank == -1 || newrank == rank,
|
||||
"must keep rank to same if the node already have one");
|
||||
rank = newrank;
|
||||
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
|
||||
sizeof(num_neighbors), "ReConnectLink failure 4");
|
||||
for (int i = 0; i < num_neighbors; ++i) {
|
||||
int nrank;
|
||||
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
|
||||
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||
"ReConnectLink failure 4");
|
||||
tree_neighbors[nrank] = 1;
|
||||
}
|
||||
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||
"ReConnectLink failure 4");
|
||||
|
||||
if (sock_listen == INVALID_SOCKET || sock_listen.AtMark()) {
|
||||
if (!sock_listen.IsClosed()) {
|
||||
sock_listen.Close();
|
||||
Assert(rank == -1 || newrank == rank,
|
||||
"must keep rank to same if the node already have one");
|
||||
rank = newrank;
|
||||
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
|
||||
sizeof(num_neighbors), "ReConnectLink failure 4");
|
||||
for (int i = 0; i < num_neighbors; ++i) {
|
||||
int nrank;
|
||||
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
|
||||
"ReConnectLink failure 4");
|
||||
tree_neighbors[nrank] = 1;
|
||||
}
|
||||
// create listening socket
|
||||
sock_listen.Create();
|
||||
sock_listen.SetKeepAlive(true);
|
||||
// http://deepix.github.io/2016/10/21/tcprst.html
|
||||
sock_listen.SetLinger(0);
|
||||
// [slave_port, slave_port+1 .... slave_port + newrank ...slave_port + nport_trial)
|
||||
// work around processes bind to same port without set reuse option,
|
||||
// start explore from slave_port + newrank towards end
|
||||
port = sock_listen.TryBindHost(slave_port+ newrank%nport_trial, slave_port + nport_trial);
|
||||
// if no port bindable, explore first half of range
|
||||
if (port == -1) sock_listen.TryBindHost(slave_port, newrank% nport_trial + slave_port);
|
||||
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
|
||||
"ReConnectLink failure 4");
|
||||
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||
"ReConnectLink failure 4");
|
||||
|
||||
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
||||
sock_listen.Listen();
|
||||
}
|
||||
|
||||
// get number of to connect and number of to accept nodes from tracker
|
||||
int num_conn, num_accept, num_error = 1;
|
||||
do {
|
||||
// send over good links
|
||||
std::vector<int> good_link;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (!all_links[i].sock.BadSocket()) {
|
||||
good_link.push_back(static_cast<int>(all_links[i].rank));
|
||||
} else {
|
||||
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
||||
if (sock_listen == INVALID_SOCKET || sock_listen.AtMark()) {
|
||||
if (!sock_listen.IsClosed()) {
|
||||
sock_listen.Close();
|
||||
}
|
||||
// create listening socket
|
||||
sock_listen.Create();
|
||||
sock_listen.SetKeepAlive(true);
|
||||
// http://deepix.github.io/2016/10/21/tcprst.html
|
||||
sock_listen.SetLinger(0);
|
||||
// [slave_port, slave_port+1 .... slave_port + newrank ...slave_port + nport_trial)
|
||||
// work around processes bind to same port without set reuse option,
|
||||
// start explore from slave_port + newrank towards end
|
||||
port = sock_listen.TryBindHost(slave_port + newrank % nport_trial, slave_port + nport_trial);
|
||||
// if no port bindable, explore first half of range
|
||||
if (port == -1) sock_listen.TryBindHost(slave_port, newrank % nport_trial + slave_port);
|
||||
|
||||
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
||||
sock_listen.Listen();
|
||||
}
|
||||
int ngood = static_cast<int>(good_link.size());
|
||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||
"ReConnectLink failure 5");
|
||||
for (size_t i = 0; i < good_link.size(); ++i) {
|
||||
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
|
||||
|
||||
// get number of to connect and number of to accept nodes from tracker
|
||||
int num_conn, num_accept, num_error = 1;
|
||||
do {
|
||||
// send over good links
|
||||
std::vector<int> good_link;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (!all_links[i].sock.BadSocket()) {
|
||||
good_link.push_back(static_cast<int>(all_links[i].rank));
|
||||
} else {
|
||||
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
||||
}
|
||||
}
|
||||
int ngood = static_cast<int>(good_link.size());
|
||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||
"ReConnectLink failure 5");
|
||||
for (size_t i = 0; i < good_link.size(); ++i) {
|
||||
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
|
||||
sizeof(good_link[i]), "ReConnectLink failure 6");
|
||||
}
|
||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||
"ReConnectLink failure 7");
|
||||
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
|
||||
sizeof(num_accept), "ReConnectLink failure 8");
|
||||
num_error = 0;
|
||||
for (int i = 0; i < num_conn; ++i) {
|
||||
LinkRecord r;
|
||||
int hport, hrank;
|
||||
std::string hname;
|
||||
tracker.RecvStr(&hname);
|
||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
|
||||
"ReConnectLink failure 9");
|
||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
|
||||
"ReConnectLink failure 10");
|
||||
|
||||
r.sock.Create();
|
||||
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
|
||||
num_error += 1; r.sock.Close(); continue;
|
||||
}
|
||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||
"ReConnectLink failure 7");
|
||||
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
|
||||
sizeof(num_accept), "ReConnectLink failure 8");
|
||||
num_error = 0;
|
||||
for (int i = 0; i < num_conn; ++i) {
|
||||
LinkRecord r;
|
||||
int hport, hrank;
|
||||
std::string hname;
|
||||
tracker.RecvStr(&hname);
|
||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
|
||||
"ReConnectLink failure 9");
|
||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
|
||||
"ReConnectLink failure 10");
|
||||
|
||||
r.sock.Create();
|
||||
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
|
||||
num_error += 1;
|
||||
r.sock.Close();
|
||||
continue;
|
||||
}
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 12");
|
||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||
"ReConnectLink failure 13");
|
||||
utils::Check(hrank == r.rank,
|
||||
"ReConnectLink failure, link rank inconsistent");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == hrank) {
|
||||
Assert(all_links[i].sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock;
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.push_back(r);
|
||||
}
|
||||
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
|
||||
"ReConnectLink failure 14");
|
||||
} while (num_error != 0);
|
||||
// send back socket listening port to tracker
|
||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
|
||||
"ReConnectLink failure 14");
|
||||
// close connection to tracker
|
||||
tracker.Close();
|
||||
// listen to incoming links
|
||||
for (int i = 0; i < num_accept; ++i) {
|
||||
LinkRecord r;
|
||||
r.sock = sock_listen.Accept();
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 12");
|
||||
"ReConnectLink failure 15");
|
||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||
"ReConnectLink failure 13");
|
||||
utils::Check(hrank == r.rank,
|
||||
"ReConnectLink failure, link rank inconsistent");
|
||||
"ReConnectLink failure 15");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == hrank) {
|
||||
Assert(all_links[i].sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock; match = true; break;
|
||||
if (all_links[i].rank == r.rank) {
|
||||
utils::Assert(all_links[i].sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock;
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.push_back(r);
|
||||
}
|
||||
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
|
||||
"ReConnectLink failure 14");
|
||||
} while (num_error != 0);
|
||||
// send back socket listening port to tracker
|
||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
|
||||
"ReConnectLink failure 14");
|
||||
// close connection to tracker
|
||||
tracker.Close();
|
||||
// listen to incoming links
|
||||
for (int i = 0; i < num_accept; ++i) {
|
||||
LinkRecord r;
|
||||
r.sock = sock_listen.Accept();
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 15");
|
||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||
"ReConnectLink failure 15");
|
||||
bool match = false;
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
if (all_links[i].rank == r.rank) {
|
||||
utils::Assert(all_links[i].sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock; match = true; break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.push_back(r);
|
||||
}
|
||||
|
||||
this->parent_index = -1;
|
||||
// setup tree links and ring structure
|
||||
tree_links.plinks.clear();
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
all_links[i].sock.SetNonBlock(true);
|
||||
all_links[i].sock.SetKeepAlive(true);
|
||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||
if (all_links[i].rank == parent_rank) {
|
||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||
this->parent_index = -1;
|
||||
// setup tree links and ring structure
|
||||
tree_links.plinks.clear();
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||
// set the socket to non-blocking mode, enable TCP keepalive
|
||||
all_links[i].sock.SetNonBlock(true);
|
||||
all_links[i].sock.SetKeepAlive(true);
|
||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||
if (all_links[i].rank == parent_rank) {
|
||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||
}
|
||||
tree_links.plinks.push_back(&all_links[i]);
|
||||
}
|
||||
tree_links.plinks.push_back(&all_links[i]);
|
||||
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
|
||||
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
|
||||
}
|
||||
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
|
||||
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
|
||||
Assert(parent_rank == -1 || parent_index != -1,
|
||||
"cannot find parent in the link");
|
||||
Assert(prev_rank == -1 || ring_prev != NULL,
|
||||
"cannot find prev ring in the link");
|
||||
Assert(next_rank == -1 || ring_next != NULL,
|
||||
"cannot find next ring in the link");
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "failed in ReconnectLink %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
Assert(parent_rank == -1 || parent_index != -1,
|
||||
"cannot find parent in the link");
|
||||
Assert(prev_rank == -1 || ring_prev != NULL,
|
||||
"cannot find prev ring in the link");
|
||||
Assert(next_rank == -1 || ring_next != NULL,
|
||||
"cannot find next ring in the link");
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
|
||||
@ -38,9 +38,9 @@ class AllreduceBase : public IEngine {
|
||||
AllreduceBase(void);
|
||||
virtual ~AllreduceBase(void) {}
|
||||
// initialize the manager
|
||||
virtual void Init(int argc, char* argv[]);
|
||||
virtual bool Init(int argc, char* argv[]);
|
||||
// shutdown the engine
|
||||
virtual void Shutdown(void);
|
||||
virtual bool Shutdown(void);
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
@ -369,7 +369,7 @@ class AllreduceBase : public IEngine {
|
||||
* this function is also used when the engine starts up
|
||||
* \param cmd possible command to send to the tracker
|
||||
*/
|
||||
void ReConnectLinks(const char *cmd = "start");
|
||||
bool ReConnectLinks(const char *cmd = "start");
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
*
|
||||
|
||||
@ -31,31 +31,41 @@ AllreduceRobust::AllreduceRobust(void) {
|
||||
env_vars.push_back("rabit_global_replica");
|
||||
env_vars.push_back("rabit_local_replica");
|
||||
}
|
||||
void AllreduceRobust::Init(int argc, char* argv[]) {
|
||||
AllreduceBase::Init(argc, argv);
|
||||
if (num_global_replica == 0) {
|
||||
result_buffer_round = -1;
|
||||
bool AllreduceRobust::Init(int argc, char* argv[]) {
|
||||
if (AllreduceBase::Init(argc, argv)) {
|
||||
if (num_global_replica == 0) {
|
||||
result_buffer_round = -1;
|
||||
} else {
|
||||
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*! \brief shutdown the engine */
|
||||
void AllreduceRobust::Shutdown(void) {
|
||||
// need to sync the exec before we shutdown, do a pseudo checkpoint
|
||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
||||
"Shutdown: check point must return true");
|
||||
// reset result buffer
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
// execute check ack step, load happens here
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
||||
"Shutdown: check ack must return true");
|
||||
bool AllreduceRobust::Shutdown(void) {
|
||||
try {
|
||||
// need to sync the exec before we shutdown, do a pseudo checkpoint
|
||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
||||
"Shutdown: check point must return true");
|
||||
// reset result buffer
|
||||
resbuf.Clear();
|
||||
seq_counter = 0;
|
||||
// execute check ack step, load happens here
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
||||
"Shutdown: check ack must return true");
|
||||
|
||||
#if defined (__APPLE__)
|
||||
sleep(1);
|
||||
sleep(1);
|
||||
#endif
|
||||
|
||||
AllreduceBase::Shutdown();
|
||||
return AllreduceBase::Shutdown();
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "%s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
|
||||
@ -24,9 +24,9 @@ class AllreduceRobust : public AllreduceBase {
|
||||
AllreduceRobust(void);
|
||||
virtual ~AllreduceRobust(void) {}
|
||||
// initialize the manager
|
||||
virtual void Init(int argc, char* argv[]);
|
||||
virtual bool Init(int argc, char* argv[]);
|
||||
/*! \brief shutdown the engine */
|
||||
virtual void Shutdown(void);
|
||||
virtual bool Shutdown(void);
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
|
||||
@ -162,12 +162,12 @@ struct WriteWrapper : public Serializable {
|
||||
} // namespace c_api
|
||||
} // namespace rabit
|
||||
|
||||
void RabitInit(int argc, char *argv[]) {
|
||||
rabit::Init(argc, argv);
|
||||
bool RabitInit(int argc, char *argv[]) {
|
||||
return rabit::Init(argc, argv);
|
||||
}
|
||||
|
||||
void RabitFinalize() {
|
||||
rabit::Finalize();
|
||||
bool RabitFinalize() {
|
||||
return rabit::Finalize();
|
||||
}
|
||||
|
||||
int RabitGetRank() {
|
||||
|
||||
@ -43,23 +43,29 @@ struct ThreadLocalEntry {
|
||||
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
|
||||
|
||||
/*! \brief initialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
bool Init(int argc, char *argv[]) {
|
||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||
if (e->engine.get() == nullptr) {
|
||||
e->initialized = true;
|
||||
e->engine.reset(new Manager());
|
||||
e->engine->Init(argc, argv);
|
||||
return e->engine->Init(argc, argv);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief finalize synchronization module */
|
||||
void Finalize() {
|
||||
bool Finalize() {
|
||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||
utils::Check(e->engine.get() != nullptr,
|
||||
"rabit::Finalize engine is not initialized or already been finalized.");
|
||||
e->engine->Shutdown();
|
||||
e->engine.reset(nullptr);
|
||||
e->initialized = false;
|
||||
if (e->engine->Shutdown()) {
|
||||
e->engine.reset(nullptr);
|
||||
e->initialized = false;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
|
||||
@ -82,10 +82,12 @@ class EmptyEngine : public IEngine {
|
||||
EmptyEngine manager;
|
||||
|
||||
/*! \brief initialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
bool Init(int argc, char *argv[]) {
|
||||
return true;
|
||||
}
|
||||
/*! \brief finalize synchronization module */
|
||||
void Finalize(void) {
|
||||
bool Finalize(void) {
|
||||
return true;
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
|
||||
@ -90,13 +90,25 @@ class MPIEngine : public IEngine {
|
||||
// singleton sync manager
|
||||
MPIEngine manager;
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
MPI::Init(argc, argv);
|
||||
/*! \brief initialize the synchronization module */
|
||||
bool Init(int argc, char *argv[]) {
|
||||
try {
|
||||
MPI::Init(argc, argv);
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, " failed in MPI Init %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*! \brief finalize synchronization module */
|
||||
void Finalize(void) {
|
||||
MPI::Finalize();
|
||||
bool Finalize(void) {
|
||||
try {
|
||||
MPI::Finalize();
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "failed in MPI shutdown %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user