From 65b718a5e786bd7d0a850f3fa1df0dbdab023eb1 Mon Sep 17 00:00:00 2001 From: Nan Zhu Date: Tue, 25 Jun 2019 20:05:54 -0700 Subject: [PATCH] return values in Init and Finalize (#96) * make inti function return values * address the comments --- include/rabit/c_api.h | 6 +- include/rabit/internal/engine.h | 4 +- include/rabit/internal/rabit-inl.h | 8 +- include/rabit/rabit.h | 6 +- src/allreduce_base.cc | 322 +++++++++++++++-------------- src/allreduce_base.h | 6 +- src/allreduce_robust.cc | 44 ++-- src/allreduce_robust.h | 4 +- src/c_api.cc | 8 +- src/engine.cc | 18 +- src/engine_empty.cc | 6 +- src/engine_mpi.cc | 22 +- 12 files changed, 253 insertions(+), 201 deletions(-) diff --git a/include/rabit/c_api.h b/include/rabit/c_api.h index 22e73f7be..87c3d40a4 100644 --- a/include/rabit/c_api.h +++ b/include/rabit/c_api.h @@ -32,14 +32,16 @@ typedef unsigned long rbt_ulong; // NOLINT(*) * from environment variables. * \param argc number of arguments in argv * \param argv the array of input arguments + * \return true if rabit is initialized successfully otherwise false */ -RABIT_DLL void RabitInit(int argc, char *argv[]); +RABIT_DLL bool RabitInit(int argc, char *argv[]); /*! * \brief finalize the rabit engine, * call this function after you finished all jobs. + * \return true if rabit is initialized successfully otherwise false */ -RABIT_DLL void RabitFinalize(void); +RABIT_DLL bool RabitFinalize(void); /*! * \brief get rank of current process diff --git a/include/rabit/internal/engine.h b/include/rabit/internal/engine.h index 6a7dfe4a3..e8fce8195 100644 --- a/include/rabit/internal/engine.h +++ b/include/rabit/internal/engine.h @@ -161,9 +161,9 @@ class IEngine { }; /*! \brief initializes the engine module */ -void Init(int argc, char *argv[]); +bool Init(int argc, char *argv[]); /*! \brief finalizes the engine module */ -void Finalize(void); +bool Finalize(void); /*! \brief singleton method to get engine */ IEngine *GetEngine(void); diff --git a/include/rabit/internal/rabit-inl.h b/include/rabit/internal/rabit-inl.h index f556d62e6..88f20231e 100644 --- a/include/rabit/internal/rabit-inl.h +++ b/include/rabit/internal/rabit-inl.h @@ -103,12 +103,12 @@ inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype & } // namespace op // intialize the rabit engine -inline void Init(int argc, char *argv[]) { - engine::Init(argc, argv); +inline bool Init(int argc, char *argv[]) { + return engine::Init(argc, argv); } // finalize the rabit engine -inline void Finalize(void) { - engine::Finalize(); +inline bool Finalize(void) { + return engine::Finalize(); } // get the rank of current process inline int GetRank(void) { diff --git a/include/rabit/rabit.h b/include/rabit/rabit.h index 83e8c58fe..ac7af5e95 100644 --- a/include/rabit/rabit.h +++ b/include/rabit/rabit.h @@ -73,12 +73,14 @@ struct BitOR; * \brief initializes rabit, call this once at the beginning of your program * \param argc number of arguments in argv * \param argv the array of input arguments + * \return true if initialized successfully, otherwise false */ -inline void Init(int argc, char *argv[]); +inline bool Init(int argc, char *argv[]); /*! * \brief finalizes the rabit engine, call this function after you finished with all the jobs + * \return true if finalized successfully, otherwise false */ -inline void Finalize(); +inline bool Finalize(); /*! \brief gets rank of the current process * \return rank number of worker*/ inline int GetRank(); diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index cdff4446e..fc5cc3e42 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -57,7 +57,7 @@ AllreduceBase::AllreduceBase(void) { } // initialization function -void AllreduceBase::Init(int argc, char* argv[]) { +bool AllreduceBase::Init(int argc, char* argv[]) { // setup from enviroment variables // handler to get variables from env for (size_t i = 0; i < env_vars.size(); ++i) { @@ -122,24 +122,30 @@ void AllreduceBase::Init(int argc, char* argv[]) { utils::Assert(all_links.size() == 0, "can only call Init once"); this->host_uri = utils::SockAddr::GetHostName(); // get information from tracker - this->ReConnectLinks(); + return this->ReConnectLinks(); } -void AllreduceBase::Shutdown(void) { - for (size_t i = 0; i < all_links.size(); ++i) { - all_links[i].sock.Close(); - } - all_links.clear(); - tree_links.plinks.clear(); +bool AllreduceBase::Shutdown(void) { + try { + for (size_t i = 0; i < all_links.size(); ++i) { + all_links[i].sock.Close(); + } + all_links.clear(); + tree_links.plinks.clear(); - if (tracker_uri == "NULL") return; - // notify tracker rank i have shutdown - utils::TCPSocket tracker = this->ConnectTracker(); - tracker.SendStr(std::string("shutdown")); - tracker.Close(); - // close listening sockets - sock_listen.Close(); - utils::TCPSocket::Finalize(); + if (tracker_uri == "NULL") return true; + // notify tracker rank i have shutdown + utils::TCPSocket tracker = this->ConnectTracker(); + tracker.SendStr(std::string("shutdown")); + tracker.Close(); + // close listening sockets + sock_listen.Close(); + utils::TCPSocket::Finalize(); + return true; + } catch (const std::exception& e) { + fprintf(stderr, "failed to shutdown due to %s\n", e.what()); + return false; + } } void AllreduceBase::TrackerPrint(const std::string &msg) { if (tracker_uri == "NULL") { @@ -252,167 +258,179 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const { * \brief connect to the tracker to fix the the missing links * this function is also used when the engine start up */ -void AllreduceBase::ReConnectLinks(const char *cmd) { +bool AllreduceBase::ReConnectLinks(const char *cmd) { // single node mode if (tracker_uri == "NULL") { - rank = 0; world_size = 1; return; + rank = 0; world_size = 1; return true; } - utils::TCPSocket tracker = this->ConnectTracker(); - tracker.SendStr(std::string(cmd)); + try { + utils::TCPSocket tracker = this->ConnectTracker(); + tracker.SendStr(std::string(cmd)); - // the rank of previous link, next link in ring - int prev_rank, next_rank; - // the rank of neighbors - std::map tree_neighbors; - using utils::Assert; - // get new ranks - int newrank, num_neighbors; - Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank), + // the rank of previous link, next link in ring + int prev_rank, next_rank; + // the rank of neighbors + std::map tree_neighbors; + using utils::Assert; + // get new ranks + int newrank, num_neighbors; + Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank), "ReConnectLink failure 4"); - Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==\ + Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == \ sizeof(parent_rank), "ReConnectLink failure 4"); - Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), - "ReConnectLink failure 4"); - Assert(rank == -1 || newrank == rank, - "must keep rank to same if the node already have one"); - rank = newrank; - Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \ - sizeof(num_neighbors), "ReConnectLink failure 4"); - for (int i = 0; i < num_neighbors; ++i) { - int nrank; - Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank), + Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 4"); - tree_neighbors[nrank] = 1; - } - Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank), - "ReConnectLink failure 4"); - Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank), - "ReConnectLink failure 4"); - - if (sock_listen == INVALID_SOCKET || sock_listen.AtMark()) { - if (!sock_listen.IsClosed()) { - sock_listen.Close(); + Assert(rank == -1 || newrank == rank, + "must keep rank to same if the node already have one"); + rank = newrank; + Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \ + sizeof(num_neighbors), "ReConnectLink failure 4"); + for (int i = 0; i < num_neighbors; ++i) { + int nrank; + Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank), + "ReConnectLink failure 4"); + tree_neighbors[nrank] = 1; } - // create listening socket - sock_listen.Create(); - sock_listen.SetKeepAlive(true); - // http://deepix.github.io/2016/10/21/tcprst.html - sock_listen.SetLinger(0); - // [slave_port, slave_port+1 .... slave_port + newrank ...slave_port + nport_trial) - // work around processes bind to same port without set reuse option, - // start explore from slave_port + newrank towards end - port = sock_listen.TryBindHost(slave_port+ newrank%nport_trial, slave_port + nport_trial); - // if no port bindable, explore first half of range - if (port == -1) sock_listen.TryBindHost(slave_port, newrank% nport_trial + slave_port); + Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank), + "ReConnectLink failure 4"); + Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank), + "ReConnectLink failure 4"); - utils::Check(port != -1, "ReConnectLink fail to bind the ports specified"); - sock_listen.Listen(); - } - - // get number of to connect and number of to accept nodes from tracker - int num_conn, num_accept, num_error = 1; - do { - // send over good links - std::vector good_link; - for (size_t i = 0; i < all_links.size(); ++i) { - if (!all_links[i].sock.BadSocket()) { - good_link.push_back(static_cast(all_links[i].rank)); - } else { - if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close(); + if (sock_listen == INVALID_SOCKET || sock_listen.AtMark()) { + if (!sock_listen.IsClosed()) { + sock_listen.Close(); } + // create listening socket + sock_listen.Create(); + sock_listen.SetKeepAlive(true); + // http://deepix.github.io/2016/10/21/tcprst.html + sock_listen.SetLinger(0); + // [slave_port, slave_port+1 .... slave_port + newrank ...slave_port + nport_trial) + // work around processes bind to same port without set reuse option, + // start explore from slave_port + newrank towards end + port = sock_listen.TryBindHost(slave_port + newrank % nport_trial, slave_port + nport_trial); + // if no port bindable, explore first half of range + if (port == -1) sock_listen.TryBindHost(slave_port, newrank % nport_trial + slave_port); + + utils::Check(port != -1, "ReConnectLink fail to bind the ports specified"); + sock_listen.Listen(); } - int ngood = static_cast(good_link.size()); - Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), - "ReConnectLink failure 5"); - for (size_t i = 0; i < good_link.size(); ++i) { - Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \ + + // get number of to connect and number of to accept nodes from tracker + int num_conn, num_accept, num_error = 1; + do { + // send over good links + std::vector good_link; + for (size_t i = 0; i < all_links.size(); ++i) { + if (!all_links[i].sock.BadSocket()) { + good_link.push_back(static_cast(all_links[i].rank)); + } else { + if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close(); + } + } + int ngood = static_cast(good_link.size()); + Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood), + "ReConnectLink failure 5"); + for (size_t i = 0; i < good_link.size(); ++i) { + Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \ sizeof(good_link[i]), "ReConnectLink failure 6"); - } - Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), - "ReConnectLink failure 7"); - Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \ - sizeof(num_accept), "ReConnectLink failure 8"); - num_error = 0; - for (int i = 0; i < num_conn; ++i) { - LinkRecord r; - int hport, hrank; - std::string hname; - tracker.RecvStr(&hname); - Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), - "ReConnectLink failure 9"); - Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), - "ReConnectLink failure 10"); - - r.sock.Create(); - if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) { - num_error += 1; r.sock.Close(); continue; } + Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn), + "ReConnectLink failure 7"); + Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \ + sizeof(num_accept), "ReConnectLink failure 8"); + num_error = 0; + for (int i = 0; i < num_conn; ++i) { + LinkRecord r; + int hport, hrank; + std::string hname; + tracker.RecvStr(&hname); + Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport), + "ReConnectLink failure 9"); + Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank), + "ReConnectLink failure 10"); + + r.sock.Create(); + if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) { + num_error += 1; + r.sock.Close(); + continue; + } + Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), + "ReConnectLink failure 12"); + Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), + "ReConnectLink failure 13"); + utils::Check(hrank == r.rank, + "ReConnectLink failure, link rank inconsistent"); + bool match = false; + for (size_t i = 0; i < all_links.size(); ++i) { + if (all_links[i].rank == hrank) { + Assert(all_links[i].sock.IsClosed(), + "Override a link that is active"); + all_links[i].sock = r.sock; + match = true; + break; + } + } + if (!match) all_links.push_back(r); + } + Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), + "ReConnectLink failure 14"); + } while (num_error != 0); + // send back socket listening port to tracker + Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), + "ReConnectLink failure 14"); + // close connection to tracker + tracker.Close(); + // listen to incoming links + for (int i = 0; i < num_accept; ++i) { + LinkRecord r; + r.sock = sock_listen.Accept(); Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), - "ReConnectLink failure 12"); + "ReConnectLink failure 15"); Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), - "ReConnectLink failure 13"); - utils::Check(hrank == r.rank, - "ReConnectLink failure, link rank inconsistent"); + "ReConnectLink failure 15"); bool match = false; for (size_t i = 0; i < all_links.size(); ++i) { - if (all_links[i].rank == hrank) { - Assert(all_links[i].sock.IsClosed(), - "Override a link that is active"); - all_links[i].sock = r.sock; match = true; break; + if (all_links[i].rank == r.rank) { + utils::Assert(all_links[i].sock.IsClosed(), + "Override a link that is active"); + all_links[i].sock = r.sock; + match = true; + break; } } if (!match) all_links.push_back(r); } - Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error), - "ReConnectLink failure 14"); - } while (num_error != 0); - // send back socket listening port to tracker - Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port), - "ReConnectLink failure 14"); - // close connection to tracker - tracker.Close(); - // listen to incoming links - for (int i = 0; i < num_accept; ++i) { - LinkRecord r; - r.sock = sock_listen.Accept(); - Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank), - "ReConnectLink failure 15"); - Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank), - "ReConnectLink failure 15"); - bool match = false; - for (size_t i = 0; i < all_links.size(); ++i) { - if (all_links[i].rank == r.rank) { - utils::Assert(all_links[i].sock.IsClosed(), - "Override a link that is active"); - all_links[i].sock = r.sock; match = true; break; - } - } - if (!match) all_links.push_back(r); - } - this->parent_index = -1; - // setup tree links and ring structure - tree_links.plinks.clear(); - for (size_t i = 0; i < all_links.size(); ++i) { - utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket"); - // set the socket to non-blocking mode, enable TCP keepalive - all_links[i].sock.SetNonBlock(true); - all_links[i].sock.SetKeepAlive(true); - if (tree_neighbors.count(all_links[i].rank) != 0) { - if (all_links[i].rank == parent_rank) { - parent_index = static_cast(tree_links.plinks.size()); + this->parent_index = -1; + // setup tree links and ring structure + tree_links.plinks.clear(); + for (size_t i = 0; i < all_links.size(); ++i) { + utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket"); + // set the socket to non-blocking mode, enable TCP keepalive + all_links[i].sock.SetNonBlock(true); + all_links[i].sock.SetKeepAlive(true); + if (tree_neighbors.count(all_links[i].rank) != 0) { + if (all_links[i].rank == parent_rank) { + parent_index = static_cast(tree_links.plinks.size()); + } + tree_links.plinks.push_back(&all_links[i]); } - tree_links.plinks.push_back(&all_links[i]); + if (all_links[i].rank == prev_rank) ring_prev = &all_links[i]; + if (all_links[i].rank == next_rank) ring_next = &all_links[i]; } - if (all_links[i].rank == prev_rank) ring_prev = &all_links[i]; - if (all_links[i].rank == next_rank) ring_next = &all_links[i]; + Assert(parent_rank == -1 || parent_index != -1, + "cannot find parent in the link"); + Assert(prev_rank == -1 || ring_prev != NULL, + "cannot find prev ring in the link"); + Assert(next_rank == -1 || ring_next != NULL, + "cannot find next ring in the link"); + return true; + } catch (const std::exception& e) { + fprintf(stderr, "failed in ReconnectLink %s\n", e.what()); + return false; } - Assert(parent_rank == -1 || parent_index != -1, - "cannot find parent in the link"); - Assert(prev_rank == -1 || ring_prev != NULL, - "cannot find prev ring in the link"); - Assert(next_rank == -1 || ring_next != NULL, - "cannot find next ring in the link"); } /*! * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure diff --git a/src/allreduce_base.h b/src/allreduce_base.h index b83cb0d02..bff53328e 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -38,9 +38,9 @@ class AllreduceBase : public IEngine { AllreduceBase(void); virtual ~AllreduceBase(void) {} // initialize the manager - virtual void Init(int argc, char* argv[]); + virtual bool Init(int argc, char* argv[]); // shutdown the engine - virtual void Shutdown(void); + virtual bool Shutdown(void); /*! * \brief set parameters to the engine * \param name parameter name @@ -369,7 +369,7 @@ class AllreduceBase : public IEngine { * this function is also used when the engine start up * \param cmd possible command to sent to tracker */ - void ReConnectLinks(const char *cmd = "start"); + bool ReConnectLinks(const char *cmd = "start"); /*! * \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure * diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index ce5a56162..be37f1659 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -31,31 +31,41 @@ AllreduceRobust::AllreduceRobust(void) { env_vars.push_back("rabit_global_replica"); env_vars.push_back("rabit_local_replica"); } -void AllreduceRobust::Init(int argc, char* argv[]) { - AllreduceBase::Init(argc, argv); - if (num_global_replica == 0) { - result_buffer_round = -1; +bool AllreduceRobust::Init(int argc, char* argv[]) { + if (AllreduceBase::Init(argc, argv)) { + if (num_global_replica == 0) { + result_buffer_round = -1; + } else { + result_buffer_round = std::max(world_size / num_global_replica, 1); + } + return true; } else { - result_buffer_round = std::max(world_size / num_global_replica, 1); + return false; } } /*! \brief shutdown the engine */ -void AllreduceRobust::Shutdown(void) { - // need to sync the exec before we shutdown, do a pesudo check point - // execute checkpoint, note: when checkpoint existing, load will not happen - utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp), - "Shutdown: check point must return true"); - // reset result buffer - resbuf.Clear(); seq_counter = 0; - // execute check ack step, load happens here - utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp), - "Shutdown: check ack must return true"); +bool AllreduceRobust::Shutdown(void) { + try { + // need to sync the exec before we shutdown, do a pesudo check point + // execute checkpoint, note: when checkpoint existing, load will not happen + utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp), + "Shutdown: check point must return true"); + // reset result buffer + resbuf.Clear(); + seq_counter = 0; + // execute check ack step, load happens here + utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp), + "Shutdown: check ack must return true"); #if defined (__APPLE__) - sleep(1); + sleep(1); #endif - AllreduceBase::Shutdown(); + return AllreduceBase::Shutdown(); + } catch (const std::exception& e) { + fprintf(stderr, "%s\n", e.what()); + return false; + } } /*! * \brief set parameters to the engine diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index c8860822d..66c24f183 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -24,9 +24,9 @@ class AllreduceRobust : public AllreduceBase { AllreduceRobust(void); virtual ~AllreduceRobust(void) {} // initialize the manager - virtual void Init(int argc, char* argv[]); + virtual bool Init(int argc, char* argv[]); /*! \brief shutdown the engine */ - virtual void Shutdown(void); + virtual bool Shutdown(void); /*! * \brief set parameters to the engine * \param name parameter name diff --git a/src/c_api.cc b/src/c_api.cc index 8c789c08b..5e8132cc9 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -162,12 +162,12 @@ struct WriteWrapper : public Serializable { } // namespace c_api } // namespace rabit -void RabitInit(int argc, char *argv[]) { - rabit::Init(argc, argv); +bool RabitInit(int argc, char *argv[]) { + return rabit::Init(argc, argv); } -void RabitFinalize() { - rabit::Finalize(); +bool RabitFinalize() { + return rabit::Finalize(); } int RabitGetRank() { diff --git a/src/engine.cc b/src/engine.cc index b57e9445b..d2b94e1d1 100644 --- a/src/engine.cc +++ b/src/engine.cc @@ -43,23 +43,29 @@ struct ThreadLocalEntry { typedef ThreadLocalStore EngineThreadLocal; /*! \brief intiialize the synchronization module */ -void Init(int argc, char *argv[]) { +bool Init(int argc, char *argv[]) { ThreadLocalEntry* e = EngineThreadLocal::Get(); if (e->engine.get() == nullptr) { e->initialized = true; e->engine.reset(new Manager()); - e->engine->Init(argc, argv); + return e->engine->Init(argc, argv); + } else { + return true; } } /*! \brief finalize syncrhonization module */ -void Finalize() { +bool Finalize() { ThreadLocalEntry* e = EngineThreadLocal::Get(); utils::Check(e->engine.get() != nullptr, "rabit::Finalize engine is not initialized or already been finalized."); - e->engine->Shutdown(); - e->engine.reset(nullptr); - e->initialized = false; + if (e->engine->Shutdown()) { + e->engine.reset(nullptr); + e->initialized = false; + return true; + } else { + return false; + } } /*! \brief singleton method to get engine */ diff --git a/src/engine_empty.cc b/src/engine_empty.cc index c3b210d9a..c917a3d9d 100644 --- a/src/engine_empty.cc +++ b/src/engine_empty.cc @@ -82,10 +82,12 @@ class EmptyEngine : public IEngine { EmptyEngine manager; /*! \brief intiialize the synchronization module */ -void Init(int argc, char *argv[]) { +bool Init(int argc, char *argv[]) { + return true; } /*! \brief finalize syncrhonization module */ -void Finalize(void) { +bool Finalize(void) { + return true; } /*! \brief singleton method to get engine */ diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 6dc6cb7ca..d345b9486 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -90,13 +90,25 @@ class MPIEngine : public IEngine { // singleton sync manager MPIEngine manager; -/*! \brief intiialize the synchronization module */ -void Init(int argc, char *argv[]) { - MPI::Init(argc, argv); +/*! \brief initialize the synchronization module */ +bool Init(int argc, char *argv[]) { + try { + MPI::Init(argc, argv); + return true; + } catch (const std::exception& e) { + fprintf(stderr, " failed in MPI Init %s\n", e.what()); + return false; + } } /*! \brief finalize syncrhonization module */ -void Finalize(void) { - MPI::Finalize(); +bool Finalize(void) { + try { + MPI::Finalize(); + return true; + } catch (const std::exception& e) { + fprintf(stderr, "failed in MPI shutdown %s\n", e.what()); + return false; + } } /*! \brief singleton method to get engine */