return values in Init and Finalize (#96)
* make inti function return values * address the comments
This commit is contained in:
parent
fc85f776f4
commit
65b718a5e7
@ -32,14 +32,16 @@ typedef unsigned long rbt_ulong; // NOLINT(*)
|
||||
* from environment variables.
|
||||
* \param argc number of arguments in argv
|
||||
* \param argv the array of input arguments
|
||||
* \return true if rabit is initialized successfully otherwise false
|
||||
*/
|
||||
RABIT_DLL void RabitInit(int argc, char *argv[]);
|
||||
RABIT_DLL bool RabitInit(int argc, char *argv[]);
|
||||
|
||||
/*!
|
||||
* \brief finalize the rabit engine,
|
||||
* call this function after you finished all jobs.
|
||||
* \return true if rabit is initialized successfully otherwise false
|
||||
*/
|
||||
RABIT_DLL void RabitFinalize(void);
|
||||
RABIT_DLL bool RabitFinalize(void);
|
||||
|
||||
/*!
|
||||
* \brief get rank of current process
|
||||
|
||||
@ -161,9 +161,9 @@ class IEngine {
|
||||
};
|
||||
|
||||
/*! \brief initializes the engine module */
|
||||
void Init(int argc, char *argv[]);
|
||||
bool Init(int argc, char *argv[]);
|
||||
/*! \brief finalizes the engine module */
|
||||
void Finalize(void);
|
||||
bool Finalize(void);
|
||||
/*! \brief singleton method to get engine */
|
||||
IEngine *GetEngine(void);
|
||||
|
||||
|
||||
@ -103,12 +103,12 @@ inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &
|
||||
} // namespace op
|
||||
|
||||
// intialize the rabit engine
|
||||
inline void Init(int argc, char *argv[]) {
|
||||
engine::Init(argc, argv);
|
||||
inline bool Init(int argc, char *argv[]) {
|
||||
return engine::Init(argc, argv);
|
||||
}
|
||||
// finalize the rabit engine
|
||||
inline void Finalize(void) {
|
||||
engine::Finalize();
|
||||
inline bool Finalize(void) {
|
||||
return engine::Finalize();
|
||||
}
|
||||
// get the rank of current process
|
||||
inline int GetRank(void) {
|
||||
|
||||
@ -73,12 +73,14 @@ struct BitOR;
|
||||
* \brief initializes rabit, call this once at the beginning of your program
|
||||
* \param argc number of arguments in argv
|
||||
* \param argv the array of input arguments
|
||||
* \return true if initialized successfully, otherwise false
|
||||
*/
|
||||
inline void Init(int argc, char *argv[]);
|
||||
inline bool Init(int argc, char *argv[]);
|
||||
/*!
|
||||
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
|
||||
* \return true if finalized successfully, otherwise false
|
||||
*/
|
||||
inline void Finalize();
|
||||
inline bool Finalize();
|
||||
/*! \brief gets rank of the current process
|
||||
* \return rank number of worker*/
|
||||
inline int GetRank();
|
||||
|
||||
@ -57,7 +57,7 @@ AllreduceBase::AllreduceBase(void) {
|
||||
}
|
||||
|
||||
// initialization function
|
||||
void AllreduceBase::Init(int argc, char* argv[]) {
|
||||
bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||
// setup from enviroment variables
|
||||
// handler to get variables from env
|
||||
for (size_t i = 0; i < env_vars.size(); ++i) {
|
||||
@ -122,17 +122,18 @@ void AllreduceBase::Init(int argc, char* argv[]) {
|
||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||
this->host_uri = utils::SockAddr::GetHostName();
|
||||
// get information from tracker
|
||||
this->ReConnectLinks();
|
||||
return this->ReConnectLinks();
|
||||
}
|
||||
|
||||
void AllreduceBase::Shutdown(void) {
|
||||
bool AllreduceBase::Shutdown(void) {
|
||||
try {
|
||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||
all_links[i].sock.Close();
|
||||
}
|
||||
all_links.clear();
|
||||
tree_links.plinks.clear();
|
||||
|
||||
if (tracker_uri == "NULL") return;
|
||||
if (tracker_uri == "NULL") return true;
|
||||
// notify tracker rank i have shutdown
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string("shutdown"));
|
||||
@ -140,6 +141,11 @@ void AllreduceBase::Shutdown(void) {
|
||||
// close listening sockets
|
||||
sock_listen.Close();
|
||||
utils::TCPSocket::Finalize();
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "failed to shutdown due to %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||
if (tracker_uri == "NULL") {
|
||||
@ -252,11 +258,12 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
||||
* \brief connect to the tracker to fix the the missing links
|
||||
* this function is also used when the engine start up
|
||||
*/
|
||||
void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
// single node mode
|
||||
if (tracker_uri == "NULL") {
|
||||
rank = 0; world_size = 1; return;
|
||||
rank = 0; world_size = 1; return true;
|
||||
}
|
||||
try {
|
||||
utils::TCPSocket tracker = this->ConnectTracker();
|
||||
tracker.SendStr(std::string(cmd));
|
||||
|
||||
@ -345,7 +352,9 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
|
||||
r.sock.Create();
|
||||
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
|
||||
num_error += 1; r.sock.Close(); continue;
|
||||
num_error += 1;
|
||||
r.sock.Close();
|
||||
continue;
|
||||
}
|
||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||
"ReConnectLink failure 12");
|
||||
@ -358,7 +367,9 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
if (all_links[i].rank == hrank) {
|
||||
Assert(all_links[i].sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock; match = true; break;
|
||||
all_links[i].sock = r.sock;
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.push_back(r);
|
||||
@ -384,7 +395,9 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
if (all_links[i].rank == r.rank) {
|
||||
utils::Assert(all_links[i].sock.IsClosed(),
|
||||
"Override a link that is active");
|
||||
all_links[i].sock = r.sock; match = true; break;
|
||||
all_links[i].sock = r.sock;
|
||||
match = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) all_links.push_back(r);
|
||||
@ -413,6 +426,11 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||
"cannot find prev ring in the link");
|
||||
Assert(next_rank == -1 || ring_next != NULL,
|
||||
"cannot find next ring in the link");
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "failed in ReconnectLink %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
|
||||
@ -38,9 +38,9 @@ class AllreduceBase : public IEngine {
|
||||
AllreduceBase(void);
|
||||
virtual ~AllreduceBase(void) {}
|
||||
// initialize the manager
|
||||
virtual void Init(int argc, char* argv[]);
|
||||
virtual bool Init(int argc, char* argv[]);
|
||||
// shutdown the engine
|
||||
virtual void Shutdown(void);
|
||||
virtual bool Shutdown(void);
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
@ -369,7 +369,7 @@ class AllreduceBase : public IEngine {
|
||||
* this function is also used when the engine start up
|
||||
* \param cmd possible command to sent to tracker
|
||||
*/
|
||||
void ReConnectLinks(const char *cmd = "start");
|
||||
bool ReConnectLinks(const char *cmd = "start");
|
||||
/*!
|
||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||
*
|
||||
|
||||
@ -31,22 +31,28 @@ AllreduceRobust::AllreduceRobust(void) {
|
||||
env_vars.push_back("rabit_global_replica");
|
||||
env_vars.push_back("rabit_local_replica");
|
||||
}
|
||||
void AllreduceRobust::Init(int argc, char* argv[]) {
|
||||
AllreduceBase::Init(argc, argv);
|
||||
bool AllreduceRobust::Init(int argc, char* argv[]) {
|
||||
if (AllreduceBase::Init(argc, argv)) {
|
||||
if (num_global_replica == 0) {
|
||||
result_buffer_round = -1;
|
||||
} else {
|
||||
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*! \brief shutdown the engine */
|
||||
void AllreduceRobust::Shutdown(void) {
|
||||
bool AllreduceRobust::Shutdown(void) {
|
||||
try {
|
||||
// need to sync the exec before we shutdown, do a pesudo check point
|
||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
||||
"Shutdown: check point must return true");
|
||||
// reset result buffer
|
||||
resbuf.Clear(); seq_counter = 0;
|
||||
resbuf.Clear();
|
||||
seq_counter = 0;
|
||||
// execute check ack step, load happens here
|
||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
||||
"Shutdown: check ack must return true");
|
||||
@ -55,7 +61,11 @@ void AllreduceRobust::Shutdown(void) {
|
||||
sleep(1);
|
||||
#endif
|
||||
|
||||
AllreduceBase::Shutdown();
|
||||
return AllreduceBase::Shutdown();
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "%s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
|
||||
@ -24,9 +24,9 @@ class AllreduceRobust : public AllreduceBase {
|
||||
AllreduceRobust(void);
|
||||
virtual ~AllreduceRobust(void) {}
|
||||
// initialize the manager
|
||||
virtual void Init(int argc, char* argv[]);
|
||||
virtual bool Init(int argc, char* argv[]);
|
||||
/*! \brief shutdown the engine */
|
||||
virtual void Shutdown(void);
|
||||
virtual bool Shutdown(void);
|
||||
/*!
|
||||
* \brief set parameters to the engine
|
||||
* \param name parameter name
|
||||
|
||||
@ -162,12 +162,12 @@ struct WriteWrapper : public Serializable {
|
||||
} // namespace c_api
|
||||
} // namespace rabit
|
||||
|
||||
void RabitInit(int argc, char *argv[]) {
|
||||
rabit::Init(argc, argv);
|
||||
bool RabitInit(int argc, char *argv[]) {
|
||||
return rabit::Init(argc, argv);
|
||||
}
|
||||
|
||||
void RabitFinalize() {
|
||||
rabit::Finalize();
|
||||
bool RabitFinalize() {
|
||||
return rabit::Finalize();
|
||||
}
|
||||
|
||||
int RabitGetRank() {
|
||||
|
||||
@ -43,23 +43,29 @@ struct ThreadLocalEntry {
|
||||
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
bool Init(int argc, char *argv[]) {
|
||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||
if (e->engine.get() == nullptr) {
|
||||
e->initialized = true;
|
||||
e->engine.reset(new Manager());
|
||||
e->engine->Init(argc, argv);
|
||||
return e->engine->Init(argc, argv);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief finalize syncrhonization module */
|
||||
void Finalize() {
|
||||
bool Finalize() {
|
||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||
utils::Check(e->engine.get() != nullptr,
|
||||
"rabit::Finalize engine is not initialized or already been finalized.");
|
||||
e->engine->Shutdown();
|
||||
if (e->engine->Shutdown()) {
|
||||
e->engine.reset(nullptr);
|
||||
e->initialized = false;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
|
||||
@ -82,10 +82,12 @@ class EmptyEngine : public IEngine {
|
||||
EmptyEngine manager;
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
bool Init(int argc, char *argv[]) {
|
||||
return true;
|
||||
}
|
||||
/*! \brief finalize syncrhonization module */
|
||||
void Finalize(void) {
|
||||
bool Finalize(void) {
|
||||
return true;
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
|
||||
@ -90,13 +90,25 @@ class MPIEngine : public IEngine {
|
||||
// singleton sync manager
|
||||
MPIEngine manager;
|
||||
|
||||
/*! \brief intiialize the synchronization module */
|
||||
void Init(int argc, char *argv[]) {
|
||||
/*! \brief initialize the synchronization module */
|
||||
bool Init(int argc, char *argv[]) {
|
||||
try {
|
||||
MPI::Init(argc, argv);
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, " failed in MPI Init %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
/*! \brief finalize syncrhonization module */
|
||||
void Finalize(void) {
|
||||
bool Finalize(void) {
|
||||
try {
|
||||
MPI::Finalize();
|
||||
return true;
|
||||
} catch (const std::exception& e) {
|
||||
fprintf(stderr, "failed in MPI shutdown %s\n", e.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief singleton method to get engine */
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user