return values in Init and Finalize (#96)

* make inti function return values

* address the comments
This commit is contained in:
Nan Zhu 2019-06-25 20:05:54 -07:00 committed by GitHub
parent fc85f776f4
commit 65b718a5e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 253 additions and 201 deletions

View File

@ -32,14 +32,16 @@ typedef unsigned long rbt_ulong; // NOLINT(*)
* from environment variables.
* \param argc number of arguments in argv
* \param argv the array of input arguments
* \return true if rabit is initialized successfully otherwise false
*/
RABIT_DLL void RabitInit(int argc, char *argv[]);
RABIT_DLL bool RabitInit(int argc, char *argv[]);
/*!
* \brief finalize the rabit engine,
* call this function after you finished all jobs.
* \return true if rabit is initialized successfully otherwise false
*/
RABIT_DLL void RabitFinalize(void);
RABIT_DLL bool RabitFinalize(void);
/*!
* \brief get rank of current process

View File

@ -161,9 +161,9 @@ class IEngine {
};
/*! \brief initializes the engine module */
void Init(int argc, char *argv[]);
bool Init(int argc, char *argv[]);
/*! \brief finalizes the engine module */
void Finalize(void);
bool Finalize(void);
/*! \brief singleton method to get engine */
IEngine *GetEngine(void);

View File

@ -103,12 +103,12 @@ inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &
} // namespace op
// intialize the rabit engine
inline void Init(int argc, char *argv[]) {
engine::Init(argc, argv);
inline bool Init(int argc, char *argv[]) {
return engine::Init(argc, argv);
}
// finalize the rabit engine
inline void Finalize(void) {
engine::Finalize();
inline bool Finalize(void) {
return engine::Finalize();
}
// get the rank of current process
inline int GetRank(void) {

View File

@ -73,12 +73,14 @@ struct BitOR;
* \brief initializes rabit, call this once at the beginning of your program
* \param argc number of arguments in argv
* \param argv the array of input arguments
* \return true if initialized successfully, otherwise false
*/
inline void Init(int argc, char *argv[]);
inline bool Init(int argc, char *argv[]);
/*!
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
* \return true if finalized successfully, otherwise false
*/
inline void Finalize();
inline bool Finalize();
/*! \brief gets rank of the current process
* \return rank number of worker*/
inline int GetRank();

View File

@ -57,7 +57,7 @@ AllreduceBase::AllreduceBase(void) {
}
// initialization function
void AllreduceBase::Init(int argc, char* argv[]) {
bool AllreduceBase::Init(int argc, char* argv[]) {
// setup from enviroment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
@ -122,24 +122,30 @@ void AllreduceBase::Init(int argc, char* argv[]) {
utils::Assert(all_links.size() == 0, "can only call Init once");
this->host_uri = utils::SockAddr::GetHostName();
// get information from tracker
this->ReConnectLinks();
return this->ReConnectLinks();
}
void AllreduceBase::Shutdown(void) {
for (size_t i = 0; i < all_links.size(); ++i) {
all_links[i].sock.Close();
}
all_links.clear();
tree_links.plinks.clear();
bool AllreduceBase::Shutdown(void) {
try {
for (size_t i = 0; i < all_links.size(); ++i) {
all_links[i].sock.Close();
}
all_links.clear();
tree_links.plinks.clear();
if (tracker_uri == "NULL") return;
// notify tracker rank i have shutdown
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
tracker.Close();
// close listening sockets
sock_listen.Close();
utils::TCPSocket::Finalize();
if (tracker_uri == "NULL") return true;
// notify tracker rank i have shutdown
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
tracker.Close();
// close listening sockets
sock_listen.Close();
utils::TCPSocket::Finalize();
return true;
} catch (const std::exception& e) {
fprintf(stderr, "failed to shutdown due to %s\n", e.what());
return false;
}
}
void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
@ -252,167 +258,179 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
*/
void AllreduceBase::ReConnectLinks(const char *cmd) {
bool AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0; world_size = 1; return;
rank = 0; world_size = 1; return true;
}
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string(cmd));
try {
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string(cmd));
// the rank of previous link, next link in ring
int prev_rank, next_rank;
// the rank of neighbors
std::map<int, int> tree_neighbors;
using utils::Assert;
// get new ranks
int newrank, num_neighbors;
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
// the rank of previous link, next link in ring
int prev_rank, next_rank;
// the rank of neighbors
std::map<int, int> tree_neighbors;
using utils::Assert;
// get new ranks
int newrank, num_neighbors;
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==\
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == \
sizeof(parent_rank), "ReConnectLink failure 4");
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 4");
Assert(rank == -1 || newrank == rank,
"must keep rank to same if the node already have one");
rank = newrank;
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
sizeof(num_neighbors), "ReConnectLink failure 4");
for (int i = 0; i < num_neighbors; ++i) {
int nrank;
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 4");
tree_neighbors[nrank] = 1;
}
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
if (sock_listen == INVALID_SOCKET || sock_listen.AtMark()) {
if (!sock_listen.IsClosed()) {
sock_listen.Close();
Assert(rank == -1 || newrank == rank,
"must keep rank to same if the node already have one");
rank = newrank;
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
sizeof(num_neighbors), "ReConnectLink failure 4");
for (int i = 0; i < num_neighbors; ++i) {
int nrank;
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
"ReConnectLink failure 4");
tree_neighbors[nrank] = 1;
}
// create listening socket
sock_listen.Create();
sock_listen.SetKeepAlive(true);
// http://deepix.github.io/2016/10/21/tcprst.html
sock_listen.SetLinger(0);
// [slave_port, slave_port+1 .... slave_port + newrank ...slave_port + nport_trial)
// work around processes bind to same port without set reuse option,
// start explore from slave_port + newrank towards end
port = sock_listen.TryBindHost(slave_port+ newrank%nport_trial, slave_port + nport_trial);
// if no port bindable, explore first half of range
if (port == -1) sock_listen.TryBindHost(slave_port, newrank% nport_trial + slave_port);
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
"ReConnectLink failure 4");
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
"ReConnectLink failure 4");
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
}
// get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1;
do {
// send over good links
std::vector<int> good_link;
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_links[i].rank));
} else {
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
if (sock_listen == INVALID_SOCKET || sock_listen.AtMark()) {
if (!sock_listen.IsClosed()) {
sock_listen.Close();
}
// create listening socket
sock_listen.Create();
sock_listen.SetKeepAlive(true);
// http://deepix.github.io/2016/10/21/tcprst.html
sock_listen.SetLinger(0);
// [slave_port, slave_port+1 .... slave_port + newrank ...slave_port + nport_trial)
// work around processes bind to same port without set reuse option,
// start explore from slave_port + newrank towards end
port = sock_listen.TryBindHost(slave_port + newrank % nport_trial, slave_port + nport_trial);
// if no port bindable, explore first half of range
if (port == -1) sock_listen.TryBindHost(slave_port, newrank % nport_trial + slave_port);
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
sock_listen.Listen();
}
int ngood = static_cast<int>(good_link.size());
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
for (size_t i = 0; i < good_link.size(); ++i) {
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
// get number of to connect and number of to accept nodes from tracker
int num_conn, num_accept, num_error = 1;
do {
// send over good links
std::vector<int> good_link;
for (size_t i = 0; i < all_links.size(); ++i) {
if (!all_links[i].sock.BadSocket()) {
good_link.push_back(static_cast<int>(all_links[i].rank));
} else {
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
}
}
int ngood = static_cast<int>(good_link.size());
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
"ReConnectLink failure 5");
for (size_t i = 0; i < good_link.size(); ++i) {
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
sizeof(good_link[i]), "ReConnectLink failure 6");
}
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
sizeof(num_accept), "ReConnectLink failure 8");
num_error = 0;
for (int i = 0; i < num_conn; ++i) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.RecvStr(&hname);
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
"ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
"ReConnectLink failure 10");
r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
num_error += 1; r.sock.Close(); continue;
}
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
"ReConnectLink failure 7");
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
sizeof(num_accept), "ReConnectLink failure 8");
num_error = 0;
for (int i = 0; i < num_conn; ++i) {
LinkRecord r;
int hport, hrank;
std::string hname;
tracker.RecvStr(&hname);
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
"ReConnectLink failure 9");
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
"ReConnectLink failure 10");
r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
num_error += 1;
r.sock.Close();
continue;
}
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 12");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 13");
utils::Check(hrank == r.rank,
"ReConnectLink failure, link rank inconsistent");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == hrank) {
Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock;
match = true;
break;
}
}
if (!match) all_links.push_back(r);
}
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
"ReConnectLink failure 14");
} while (num_error != 0);
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
"ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
LinkRecord r;
r.sock = sock_listen.Accept();
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 12");
"ReConnectLink failure 15");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 13");
utils::Check(hrank == r.rank,
"ReConnectLink failure, link rank inconsistent");
"ReConnectLink failure 15");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == hrank) {
Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock; match = true; break;
if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock;
match = true;
break;
}
}
if (!match) all_links.push_back(r);
}
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
"ReConnectLink failure 14");
} while (num_error != 0);
// send back socket listening port to tracker
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
"ReConnectLink failure 14");
// close connection to tracker
tracker.Close();
// listen to incoming links
for (int i = 0; i < num_accept; ++i) {
LinkRecord r;
r.sock = sock_listen.Accept();
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 15");
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
"ReConnectLink failure 15");
bool match = false;
for (size_t i = 0; i < all_links.size(); ++i) {
if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock; match = true; break;
}
}
if (!match) all_links.push_back(r);
}
this->parent_index = -1;
// setup tree links and ring structure
tree_links.plinks.clear();
for (size_t i = 0; i < all_links.size(); ++i) {
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_links[i].sock.SetNonBlock(true);
all_links[i].sock.SetKeepAlive(true);
if (tree_neighbors.count(all_links[i].rank) != 0) {
if (all_links[i].rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size());
this->parent_index = -1;
// setup tree links and ring structure
tree_links.plinks.clear();
for (size_t i = 0; i < all_links.size(); ++i) {
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
all_links[i].sock.SetNonBlock(true);
all_links[i].sock.SetKeepAlive(true);
if (tree_neighbors.count(all_links[i].rank) != 0) {
if (all_links[i].rank == parent_rank) {
parent_index = static_cast<int>(tree_links.plinks.size());
}
tree_links.plinks.push_back(&all_links[i]);
}
tree_links.plinks.push_back(&all_links[i]);
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
}
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
Assert(parent_rank == -1 || parent_index != -1,
"cannot find parent in the link");
Assert(prev_rank == -1 || ring_prev != NULL,
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
"cannot find next ring in the link");
return true;
} catch (const std::exception& e) {
fprintf(stderr, "failed in ReconnectLink %s\n", e.what());
return false;
}
Assert(parent_rank == -1 || parent_index != -1,
"cannot find parent in the link");
Assert(prev_rank == -1 || ring_prev != NULL,
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
"cannot find next ring in the link");
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure

View File

@ -38,9 +38,9 @@ class AllreduceBase : public IEngine {
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
// initialize the manager
virtual void Init(int argc, char* argv[]);
virtual bool Init(int argc, char* argv[]);
// shutdown the engine
virtual void Shutdown(void);
virtual bool Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name
@ -369,7 +369,7 @@ class AllreduceBase : public IEngine {
* this function is also used when the engine start up
* \param cmd possible command to sent to tracker
*/
void ReConnectLinks(const char *cmd = "start");
bool ReConnectLinks(const char *cmd = "start");
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*

View File

@ -31,31 +31,41 @@ AllreduceRobust::AllreduceRobust(void) {
env_vars.push_back("rabit_global_replica");
env_vars.push_back("rabit_local_replica");
}
void AllreduceRobust::Init(int argc, char* argv[]) {
AllreduceBase::Init(argc, argv);
if (num_global_replica == 0) {
result_buffer_round = -1;
bool AllreduceRobust::Init(int argc, char* argv[]) {
if (AllreduceBase::Init(argc, argv)) {
if (num_global_replica == 0) {
result_buffer_round = -1;
} else {
result_buffer_round = std::max(world_size / num_global_replica, 1);
}
return true;
} else {
result_buffer_round = std::max(world_size / num_global_replica, 1);
return false;
}
}
/*! \brief shutdown the engine */
void AllreduceRobust::Shutdown(void) {
// need to sync the exec before we shutdown, do a pesudo check point
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
"Shutdown: check point must return true");
// reset result buffer
resbuf.Clear(); seq_counter = 0;
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"Shutdown: check ack must return true");
bool AllreduceRobust::Shutdown(void) {
try {
// need to sync the exec before we shutdown, do a pesudo check point
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
"Shutdown: check point must return true");
// reset result buffer
resbuf.Clear();
seq_counter = 0;
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"Shutdown: check ack must return true");
#if defined (__APPLE__)
sleep(1);
sleep(1);
#endif
AllreduceBase::Shutdown();
return AllreduceBase::Shutdown();
} catch (const std::exception& e) {
fprintf(stderr, "%s\n", e.what());
return false;
}
}
/*!
* \brief set parameters to the engine

View File

@ -24,9 +24,9 @@ class AllreduceRobust : public AllreduceBase {
AllreduceRobust(void);
virtual ~AllreduceRobust(void) {}
// initialize the manager
virtual void Init(int argc, char* argv[]);
virtual bool Init(int argc, char* argv[]);
/*! \brief shutdown the engine */
virtual void Shutdown(void);
virtual bool Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name

View File

@ -162,12 +162,12 @@ struct WriteWrapper : public Serializable {
} // namespace c_api
} // namespace rabit
void RabitInit(int argc, char *argv[]) {
rabit::Init(argc, argv);
bool RabitInit(int argc, char *argv[]) {
return rabit::Init(argc, argv);
}
void RabitFinalize() {
rabit::Finalize();
bool RabitFinalize() {
return rabit::Finalize();
}
int RabitGetRank() {

View File

@ -43,23 +43,29 @@ struct ThreadLocalEntry {
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
/*! \brief intiialize the synchronization module */
void Init(int argc, char *argv[]) {
bool Init(int argc, char *argv[]) {
ThreadLocalEntry* e = EngineThreadLocal::Get();
if (e->engine.get() == nullptr) {
e->initialized = true;
e->engine.reset(new Manager());
e->engine->Init(argc, argv);
return e->engine->Init(argc, argv);
} else {
return true;
}
}
/*! \brief finalize syncrhonization module */
void Finalize() {
bool Finalize() {
ThreadLocalEntry* e = EngineThreadLocal::Get();
utils::Check(e->engine.get() != nullptr,
"rabit::Finalize engine is not initialized or already been finalized.");
e->engine->Shutdown();
e->engine.reset(nullptr);
e->initialized = false;
if (e->engine->Shutdown()) {
e->engine.reset(nullptr);
e->initialized = false;
return true;
} else {
return false;
}
}
/*! \brief singleton method to get engine */

View File

@ -82,10 +82,12 @@ class EmptyEngine : public IEngine {
EmptyEngine manager;
/*! \brief intiialize the synchronization module */
void Init(int argc, char *argv[]) {
bool Init(int argc, char *argv[]) {
return true;
}
/*! \brief finalize syncrhonization module */
void Finalize(void) {
bool Finalize(void) {
return true;
}
/*! \brief singleton method to get engine */

View File

@ -90,13 +90,25 @@ class MPIEngine : public IEngine {
// singleton sync manager
MPIEngine manager;
/*! \brief intiialize the synchronization module */
void Init(int argc, char *argv[]) {
MPI::Init(argc, argv);
/*! \brief initialize the synchronization module */
bool Init(int argc, char *argv[]) {
try {
MPI::Init(argc, argv);
return true;
} catch (const std::exception& e) {
fprintf(stderr, " failed in MPI Init %s\n", e.what());
return false;
}
}
/*! \brief finalize syncrhonization module */
void Finalize(void) {
MPI::Finalize();
bool Finalize(void) {
try {
MPI::Finalize();
return true;
} catch (const std::exception& e) {
fprintf(stderr, "failed in MPI shutdown %s\n", e.what());
return false;
}
}
/*! \brief singleton method to get engine */