return values in Init and Finalize (#96)

* make inti function return values

* address the comments
This commit is contained in:
Nan Zhu 2019-06-25 20:05:54 -07:00 committed by GitHub
parent fc85f776f4
commit 65b718a5e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 253 additions and 201 deletions

View File

@ -32,14 +32,16 @@ typedef unsigned long rbt_ulong; // NOLINT(*)
* from environment variables.
* \param argc number of arguments in argv
* \param argv the array of input arguments
* \return true if rabit is initialized successfully otherwise false
*/
RABIT_DLL void RabitInit(int argc, char *argv[]);
RABIT_DLL bool RabitInit(int argc, char *argv[]);
/*!
* \brief finalize the rabit engine,
* call this function after you finished all jobs.
* \return true if rabit is initialized successfully otherwise false
*/
RABIT_DLL void RabitFinalize(void);
RABIT_DLL bool RabitFinalize(void);
/*!
* \brief get rank of current process

View File

@ -161,9 +161,9 @@ class IEngine {
};
/*! \brief initializes the engine module */
void Init(int argc, char *argv[]);
bool Init(int argc, char *argv[]);
/*! \brief finalizes the engine module */
void Finalize(void);
bool Finalize(void);
/*! \brief singleton method to get engine */
IEngine *GetEngine(void);

View File

@ -103,12 +103,12 @@ inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &
} // namespace op
// intialize the rabit engine
inline void Init(int argc, char *argv[]) {
engine::Init(argc, argv);
inline bool Init(int argc, char *argv[]) {
return engine::Init(argc, argv);
}
// finalize the rabit engine
inline void Finalize(void) {
engine::Finalize();
inline bool Finalize(void) {
return engine::Finalize();
}
// get the rank of current process
inline int GetRank(void) {

View File

@ -73,12 +73,14 @@ struct BitOR;
* \brief initializes rabit, call this once at the beginning of your program
* \param argc number of arguments in argv
* \param argv the array of input arguments
* \return true if initialized successfully, otherwise false
*/
inline void Init(int argc, char *argv[]);
inline bool Init(int argc, char *argv[]);
/*!
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
* \return true if finalized successfully, otherwise false
*/
inline void Finalize();
inline bool Finalize();
/*! \brief gets rank of the current process
* \return rank number of worker*/
inline int GetRank();

View File

@ -57,7 +57,7 @@ AllreduceBase::AllreduceBase(void) {
}
// initialization function
void AllreduceBase::Init(int argc, char* argv[]) {
bool AllreduceBase::Init(int argc, char* argv[]) {
// setup from enviroment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
@ -122,17 +122,18 @@ void AllreduceBase::Init(int argc, char* argv[]) {
utils::Assert(all_links.size() == 0, "can only call Init once");
this->host_uri = utils::SockAddr::GetHostName();
// get information from tracker
this->ReConnectLinks();
return this->ReConnectLinks();
}
void AllreduceBase::Shutdown(void) {
bool AllreduceBase::Shutdown(void) {
try {
for (size_t i = 0; i < all_links.size(); ++i) {
all_links[i].sock.Close();
}
all_links.clear();
tree_links.plinks.clear();
if (tracker_uri == "NULL") return;
if (tracker_uri == "NULL") return true;
// notify tracker rank i have shutdown
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("shutdown"));
@ -140,6 +141,11 @@ void AllreduceBase::Shutdown(void) {
// close listening sockets
sock_listen.Close();
utils::TCPSocket::Finalize();
return true;
} catch (const std::exception& e) {
fprintf(stderr, "failed to shutdown due to %s\n", e.what());
return false;
}
}
void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
@ -252,11 +258,12 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
*/
void AllreduceBase::ReConnectLinks(const char *cmd) {
bool AllreduceBase::ReConnectLinks(const char *cmd) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0; world_size = 1; return;
rank = 0; world_size = 1; return true;
}
try {
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string(cmd));
@ -345,7 +352,9 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
r.sock.Create();
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
num_error += 1; r.sock.Close(); continue;
num_error += 1;
r.sock.Close();
continue;
}
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
"ReConnectLink failure 12");
@ -358,7 +367,9 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
if (all_links[i].rank == hrank) {
Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock; match = true; break;
all_links[i].sock = r.sock;
match = true;
break;
}
}
if (!match) all_links.push_back(r);
@ -384,7 +395,9 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
if (all_links[i].rank == r.rank) {
utils::Assert(all_links[i].sock.IsClosed(),
"Override a link that is active");
all_links[i].sock = r.sock; match = true; break;
all_links[i].sock = r.sock;
match = true;
break;
}
}
if (!match) all_links.push_back(r);
@ -413,6 +426,11 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
"cannot find next ring in the link");
return true;
} catch (const std::exception& e) {
fprintf(stderr, "failed in ReconnectLink %s\n", e.what());
return false;
}
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure

View File

@ -38,9 +38,9 @@ class AllreduceBase : public IEngine {
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
// initialize the manager
virtual void Init(int argc, char* argv[]);
virtual bool Init(int argc, char* argv[]);
// shutdown the engine
virtual void Shutdown(void);
virtual bool Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name
@ -369,7 +369,7 @@ class AllreduceBase : public IEngine {
* this function is also used when the engine start up
* \param cmd possible command to sent to tracker
*/
void ReConnectLinks(const char *cmd = "start");
bool ReConnectLinks(const char *cmd = "start");
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*

View File

@ -31,22 +31,28 @@ AllreduceRobust::AllreduceRobust(void) {
env_vars.push_back("rabit_global_replica");
env_vars.push_back("rabit_local_replica");
}
void AllreduceRobust::Init(int argc, char* argv[]) {
AllreduceBase::Init(argc, argv);
bool AllreduceRobust::Init(int argc, char* argv[]) {
if (AllreduceBase::Init(argc, argv)) {
if (num_global_replica == 0) {
result_buffer_round = -1;
} else {
result_buffer_round = std::max(world_size / num_global_replica, 1);
}
return true;
} else {
return false;
}
}
/*! \brief shutdown the engine */
void AllreduceRobust::Shutdown(void) {
bool AllreduceRobust::Shutdown(void) {
try {
// need to sync the exec before we shutdown, do a pesudo check point
// execute checkpoint, note: when checkpoint existing, load will not happen
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
"Shutdown: check point must return true");
// reset result buffer
resbuf.Clear(); seq_counter = 0;
resbuf.Clear();
seq_counter = 0;
// execute check ack step, load happens here
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
"Shutdown: check ack must return true");
@ -55,7 +61,11 @@ void AllreduceRobust::Shutdown(void) {
sleep(1);
#endif
AllreduceBase::Shutdown();
return AllreduceBase::Shutdown();
} catch (const std::exception& e) {
fprintf(stderr, "%s\n", e.what());
return false;
}
}
/*!
* \brief set parameters to the engine

View File

@ -24,9 +24,9 @@ class AllreduceRobust : public AllreduceBase {
AllreduceRobust(void);
virtual ~AllreduceRobust(void) {}
// initialize the manager
virtual void Init(int argc, char* argv[]);
virtual bool Init(int argc, char* argv[]);
/*! \brief shutdown the engine */
virtual void Shutdown(void);
virtual bool Shutdown(void);
/*!
* \brief set parameters to the engine
* \param name parameter name

View File

@ -162,12 +162,12 @@ struct WriteWrapper : public Serializable {
} // namespace c_api
} // namespace rabit
void RabitInit(int argc, char *argv[]) {
rabit::Init(argc, argv);
bool RabitInit(int argc, char *argv[]) {
return rabit::Init(argc, argv);
}
void RabitFinalize() {
rabit::Finalize();
bool RabitFinalize() {
return rabit::Finalize();
}
int RabitGetRank() {

View File

@ -43,23 +43,29 @@ struct ThreadLocalEntry {
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
/*! \brief intiialize the synchronization module */
void Init(int argc, char *argv[]) {
bool Init(int argc, char *argv[]) {
ThreadLocalEntry* e = EngineThreadLocal::Get();
if (e->engine.get() == nullptr) {
e->initialized = true;
e->engine.reset(new Manager());
e->engine->Init(argc, argv);
return e->engine->Init(argc, argv);
} else {
return true;
}
}
/*! \brief finalize syncrhonization module */
void Finalize() {
bool Finalize() {
ThreadLocalEntry* e = EngineThreadLocal::Get();
utils::Check(e->engine.get() != nullptr,
"rabit::Finalize engine is not initialized or already been finalized.");
e->engine->Shutdown();
if (e->engine->Shutdown()) {
e->engine.reset(nullptr);
e->initialized = false;
return true;
} else {
return false;
}
}
/*! \brief singleton method to get engine */

View File

@ -82,10 +82,12 @@ class EmptyEngine : public IEngine {
EmptyEngine manager;
/*! \brief intiialize the synchronization module */
void Init(int argc, char *argv[]) {
bool Init(int argc, char *argv[]) {
return true;
}
/*! \brief finalize syncrhonization module */
void Finalize(void) {
bool Finalize(void) {
return true;
}
/*! \brief singleton method to get engine */

View File

@ -90,13 +90,25 @@ class MPIEngine : public IEngine {
// singleton sync manager
MPIEngine manager;
/*! \brief intiialize the synchronization module */
void Init(int argc, char *argv[]) {
/*! \brief initialize the synchronization module */
bool Init(int argc, char *argv[]) {
try {
MPI::Init(argc, argv);
return true;
} catch (const std::exception& e) {
fprintf(stderr, " failed in MPI Init %s\n", e.what());
return false;
}
}
/*! \brief finalize syncrhonization module */
void Finalize(void) {
bool Finalize(void) {
try {
MPI::Finalize();
return true;
} catch (const std::exception& e) {
fprintf(stderr, "failed in MPI shutdown %s\n", e.what());
return false;
}
}
/*! \brief singleton method to get engine */