return values in Init and Finalize (#96)
* make init function return values * address the comments
This commit is contained in:
parent
fc85f776f4
commit
65b718a5e7
@ -32,14 +32,16 @@ typedef unsigned long rbt_ulong; // NOLINT(*)
|
|||||||
* from environment variables.
|
* from environment variables.
|
||||||
* \param argc number of arguments in argv
|
* \param argc number of arguments in argv
|
||||||
* \param argv the array of input arguments
|
* \param argv the array of input arguments
|
||||||
|
* \return true if rabit is initialized successfully otherwise false
|
||||||
*/
|
*/
|
||||||
RABIT_DLL void RabitInit(int argc, char *argv[]);
|
RABIT_DLL bool RabitInit(int argc, char *argv[]);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief finalize the rabit engine,
|
* \brief finalize the rabit engine,
|
||||||
* call this function after you finished all jobs.
|
* call this function after you finished all jobs.
|
||||||
|
* \return true if rabit is finalized successfully otherwise false
|
||||||
*/
|
*/
|
||||||
RABIT_DLL void RabitFinalize(void);
|
RABIT_DLL bool RabitFinalize(void);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief get rank of current process
|
* \brief get rank of current process
|
||||||
|
|||||||
@ -161,9 +161,9 @@ class IEngine {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/*! \brief initializes the engine module */
|
/*! \brief initializes the engine module */
|
||||||
void Init(int argc, char *argv[]);
|
bool Init(int argc, char *argv[]);
|
||||||
/*! \brief finalizes the engine module */
|
/*! \brief finalizes the engine module */
|
||||||
void Finalize(void);
|
bool Finalize(void);
|
||||||
/*! \brief singleton method to get engine */
|
/*! \brief singleton method to get engine */
|
||||||
IEngine *GetEngine(void);
|
IEngine *GetEngine(void);
|
||||||
|
|
||||||
|
|||||||
@ -103,12 +103,12 @@ inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &
|
|||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
||||||
// initialize the rabit engine
|
// initialize the rabit engine
|
||||||
inline void Init(int argc, char *argv[]) {
|
inline bool Init(int argc, char *argv[]) {
|
||||||
engine::Init(argc, argv);
|
return engine::Init(argc, argv);
|
||||||
}
|
}
|
||||||
// finalize the rabit engine
|
// finalize the rabit engine
|
||||||
inline void Finalize(void) {
|
inline bool Finalize(void) {
|
||||||
engine::Finalize();
|
return engine::Finalize();
|
||||||
}
|
}
|
||||||
// get the rank of current process
|
// get the rank of current process
|
||||||
inline int GetRank(void) {
|
inline int GetRank(void) {
|
||||||
|
|||||||
@ -73,12 +73,14 @@ struct BitOR;
|
|||||||
* \brief initializes rabit, call this once at the beginning of your program
|
* \brief initializes rabit, call this once at the beginning of your program
|
||||||
* \param argc number of arguments in argv
|
* \param argc number of arguments in argv
|
||||||
* \param argv the array of input arguments
|
* \param argv the array of input arguments
|
||||||
|
* \return true if initialized successfully, otherwise false
|
||||||
*/
|
*/
|
||||||
inline void Init(int argc, char *argv[]);
|
inline bool Init(int argc, char *argv[]);
|
||||||
/*!
|
/*!
|
||||||
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
|
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
|
||||||
|
* \return true if finalized successfully, otherwise false
|
||||||
*/
|
*/
|
||||||
inline void Finalize();
|
inline bool Finalize();
|
||||||
/*! \brief gets rank of the current process
|
/*! \brief gets rank of the current process
|
||||||
* \return rank number of worker*/
|
* \return rank number of worker*/
|
||||||
inline int GetRank();
|
inline int GetRank();
|
||||||
|
|||||||
@ -57,7 +57,7 @@ AllreduceBase::AllreduceBase(void) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// initialization function
|
// initialization function
|
||||||
void AllreduceBase::Init(int argc, char* argv[]) {
|
bool AllreduceBase::Init(int argc, char* argv[]) {
|
||||||
// setup from environment variables
|
// setup from environment variables
|
||||||
// handler to get variables from env
|
// handler to get variables from env
|
||||||
for (size_t i = 0; i < env_vars.size(); ++i) {
|
for (size_t i = 0; i < env_vars.size(); ++i) {
|
||||||
@ -122,24 +122,30 @@ void AllreduceBase::Init(int argc, char* argv[]) {
|
|||||||
utils::Assert(all_links.size() == 0, "can only call Init once");
|
utils::Assert(all_links.size() == 0, "can only call Init once");
|
||||||
this->host_uri = utils::SockAddr::GetHostName();
|
this->host_uri = utils::SockAddr::GetHostName();
|
||||||
// get information from tracker
|
// get information from tracker
|
||||||
this->ReConnectLinks();
|
return this->ReConnectLinks();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AllreduceBase::Shutdown(void) {
|
bool AllreduceBase::Shutdown(void) {
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
try {
|
||||||
all_links[i].sock.Close();
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
}
|
all_links[i].sock.Close();
|
||||||
all_links.clear();
|
}
|
||||||
tree_links.plinks.clear();
|
all_links.clear();
|
||||||
|
tree_links.plinks.clear();
|
||||||
|
|
||||||
if (tracker_uri == "NULL") return;
|
if (tracker_uri == "NULL") return true;
|
||||||
// notify tracker rank i have shutdown
|
// notify tracker rank i have shutdown
|
||||||
utils::TCPSocket tracker = this->ConnectTracker();
|
utils::TCPSocket tracker = this->ConnectTracker();
|
||||||
tracker.SendStr(std::string("shutdown"));
|
tracker.SendStr(std::string("shutdown"));
|
||||||
tracker.Close();
|
tracker.Close();
|
||||||
// close listening sockets
|
// close listening sockets
|
||||||
sock_listen.Close();
|
sock_listen.Close();
|
||||||
utils::TCPSocket::Finalize();
|
utils::TCPSocket::Finalize();
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
fprintf(stderr, "failed to shutdown due to %s\n", e.what());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
void AllreduceBase::TrackerPrint(const std::string &msg) {
|
void AllreduceBase::TrackerPrint(const std::string &msg) {
|
||||||
if (tracker_uri == "NULL") {
|
if (tracker_uri == "NULL") {
|
||||||
@ -252,167 +258,179 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
|
|||||||
* \brief connect to the tracker to fix the missing links
|
* \brief connect to the tracker to fix the missing links
|
||||||
* this function is also used when the engine start up
|
* this function is also used when the engine start up
|
||||||
*/
|
*/
|
||||||
void AllreduceBase::ReConnectLinks(const char *cmd) {
|
bool AllreduceBase::ReConnectLinks(const char *cmd) {
|
||||||
// single node mode
|
// single node mode
|
||||||
if (tracker_uri == "NULL") {
|
if (tracker_uri == "NULL") {
|
||||||
rank = 0; world_size = 1; return;
|
rank = 0; world_size = 1; return true;
|
||||||
}
|
}
|
||||||
utils::TCPSocket tracker = this->ConnectTracker();
|
try {
|
||||||
tracker.SendStr(std::string(cmd));
|
utils::TCPSocket tracker = this->ConnectTracker();
|
||||||
|
tracker.SendStr(std::string(cmd));
|
||||||
|
|
||||||
// the rank of previous link, next link in ring
|
// the rank of previous link, next link in ring
|
||||||
int prev_rank, next_rank;
|
int prev_rank, next_rank;
|
||||||
// the rank of neighbors
|
// the rank of neighbors
|
||||||
std::map<int, int> tree_neighbors;
|
std::map<int, int> tree_neighbors;
|
||||||
using utils::Assert;
|
using utils::Assert;
|
||||||
// get new ranks
|
// get new ranks
|
||||||
int newrank, num_neighbors;
|
int newrank, num_neighbors;
|
||||||
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
Assert(tracker.RecvAll(&newrank, sizeof(newrank)) == sizeof(newrank),
|
||||||
"ReConnectLink failure 4");
|
"ReConnectLink failure 4");
|
||||||
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) ==\
|
Assert(tracker.RecvAll(&parent_rank, sizeof(parent_rank)) == \
|
||||||
sizeof(parent_rank), "ReConnectLink failure 4");
|
sizeof(parent_rank), "ReConnectLink failure 4");
|
||||||
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
Assert(tracker.RecvAll(&world_size, sizeof(world_size)) == sizeof(world_size),
|
||||||
"ReConnectLink failure 4");
|
|
||||||
Assert(rank == -1 || newrank == rank,
|
|
||||||
"must keep rank to same if the node already have one");
|
|
||||||
rank = newrank;
|
|
||||||
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
|
|
||||||
sizeof(num_neighbors), "ReConnectLink failure 4");
|
|
||||||
for (int i = 0; i < num_neighbors; ++i) {
|
|
||||||
int nrank;
|
|
||||||
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
|
|
||||||
"ReConnectLink failure 4");
|
"ReConnectLink failure 4");
|
||||||
tree_neighbors[nrank] = 1;
|
Assert(rank == -1 || newrank == rank,
|
||||||
}
|
"must keep rank to same if the node already have one");
|
||||||
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
|
rank = newrank;
|
||||||
"ReConnectLink failure 4");
|
Assert(tracker.RecvAll(&num_neighbors, sizeof(num_neighbors)) == \
|
||||||
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
sizeof(num_neighbors), "ReConnectLink failure 4");
|
||||||
"ReConnectLink failure 4");
|
for (int i = 0; i < num_neighbors; ++i) {
|
||||||
|
int nrank;
|
||||||
if (sock_listen == INVALID_SOCKET || sock_listen.AtMark()) {
|
Assert(tracker.RecvAll(&nrank, sizeof(nrank)) == sizeof(nrank),
|
||||||
if (!sock_listen.IsClosed()) {
|
"ReConnectLink failure 4");
|
||||||
sock_listen.Close();
|
tree_neighbors[nrank] = 1;
|
||||||
}
|
}
|
||||||
// create listening socket
|
Assert(tracker.RecvAll(&prev_rank, sizeof(prev_rank)) == sizeof(prev_rank),
|
||||||
sock_listen.Create();
|
"ReConnectLink failure 4");
|
||||||
sock_listen.SetKeepAlive(true);
|
Assert(tracker.RecvAll(&next_rank, sizeof(next_rank)) == sizeof(next_rank),
|
||||||
// http://deepix.github.io/2016/10/21/tcprst.html
|
"ReConnectLink failure 4");
|
||||||
sock_listen.SetLinger(0);
|
|
||||||
// [slave_port, slave_port+1 .... slave_port + newrank ...slave_port + nport_trial)
|
|
||||||
// work around processes bind to same port without set reuse option,
|
|
||||||
// start explore from slave_port + newrank towards end
|
|
||||||
port = sock_listen.TryBindHost(slave_port+ newrank%nport_trial, slave_port + nport_trial);
|
|
||||||
// if no port bindable, explore first half of range
|
|
||||||
if (port == -1) sock_listen.TryBindHost(slave_port, newrank% nport_trial + slave_port);
|
|
||||||
|
|
||||||
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
if (sock_listen == INVALID_SOCKET || sock_listen.AtMark()) {
|
||||||
sock_listen.Listen();
|
if (!sock_listen.IsClosed()) {
|
||||||
}
|
sock_listen.Close();
|
||||||
|
|
||||||
// get number of to connect and number of to accept nodes from tracker
|
|
||||||
int num_conn, num_accept, num_error = 1;
|
|
||||||
do {
|
|
||||||
// send over good links
|
|
||||||
std::vector<int> good_link;
|
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
|
||||||
if (!all_links[i].sock.BadSocket()) {
|
|
||||||
good_link.push_back(static_cast<int>(all_links[i].rank));
|
|
||||||
} else {
|
|
||||||
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
|
||||||
}
|
}
|
||||||
|
// create listening socket
|
||||||
|
sock_listen.Create();
|
||||||
|
sock_listen.SetKeepAlive(true);
|
||||||
|
// http://deepix.github.io/2016/10/21/tcprst.html
|
||||||
|
sock_listen.SetLinger(0);
|
||||||
|
// [slave_port, slave_port+1 .... slave_port + newrank ...slave_port + nport_trial)
|
||||||
|
// work around processes bind to same port without set reuse option,
|
||||||
|
// start explore from slave_port + newrank towards end
|
||||||
|
port = sock_listen.TryBindHost(slave_port + newrank % nport_trial, slave_port + nport_trial);
|
||||||
|
// if no port bindable, explore first half of range
|
||||||
|
if (port == -1) sock_listen.TryBindHost(slave_port, newrank % nport_trial + slave_port);
|
||||||
|
|
||||||
|
utils::Check(port != -1, "ReConnectLink fail to bind the ports specified");
|
||||||
|
sock_listen.Listen();
|
||||||
}
|
}
|
||||||
int ngood = static_cast<int>(good_link.size());
|
|
||||||
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
// get number of to connect and number of to accept nodes from tracker
|
||||||
"ReConnectLink failure 5");
|
int num_conn, num_accept, num_error = 1;
|
||||||
for (size_t i = 0; i < good_link.size(); ++i) {
|
do {
|
||||||
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
|
// send over good links
|
||||||
|
std::vector<int> good_link;
|
||||||
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
|
if (!all_links[i].sock.BadSocket()) {
|
||||||
|
good_link.push_back(static_cast<int>(all_links[i].rank));
|
||||||
|
} else {
|
||||||
|
if (!all_links[i].sock.IsClosed()) all_links[i].sock.Close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int ngood = static_cast<int>(good_link.size());
|
||||||
|
Assert(tracker.SendAll(&ngood, sizeof(ngood)) == sizeof(ngood),
|
||||||
|
"ReConnectLink failure 5");
|
||||||
|
for (size_t i = 0; i < good_link.size(); ++i) {
|
||||||
|
Assert(tracker.SendAll(&good_link[i], sizeof(good_link[i])) == \
|
||||||
sizeof(good_link[i]), "ReConnectLink failure 6");
|
sizeof(good_link[i]), "ReConnectLink failure 6");
|
||||||
}
|
|
||||||
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
|
||||||
"ReConnectLink failure 7");
|
|
||||||
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
|
|
||||||
sizeof(num_accept), "ReConnectLink failure 8");
|
|
||||||
num_error = 0;
|
|
||||||
for (int i = 0; i < num_conn; ++i) {
|
|
||||||
LinkRecord r;
|
|
||||||
int hport, hrank;
|
|
||||||
std::string hname;
|
|
||||||
tracker.RecvStr(&hname);
|
|
||||||
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
|
|
||||||
"ReConnectLink failure 9");
|
|
||||||
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
|
|
||||||
"ReConnectLink failure 10");
|
|
||||||
|
|
||||||
r.sock.Create();
|
|
||||||
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
|
|
||||||
num_error += 1; r.sock.Close(); continue;
|
|
||||||
}
|
}
|
||||||
|
Assert(tracker.RecvAll(&num_conn, sizeof(num_conn)) == sizeof(num_conn),
|
||||||
|
"ReConnectLink failure 7");
|
||||||
|
Assert(tracker.RecvAll(&num_accept, sizeof(num_accept)) == \
|
||||||
|
sizeof(num_accept), "ReConnectLink failure 8");
|
||||||
|
num_error = 0;
|
||||||
|
for (int i = 0; i < num_conn; ++i) {
|
||||||
|
LinkRecord r;
|
||||||
|
int hport, hrank;
|
||||||
|
std::string hname;
|
||||||
|
tracker.RecvStr(&hname);
|
||||||
|
Assert(tracker.RecvAll(&hport, sizeof(hport)) == sizeof(hport),
|
||||||
|
"ReConnectLink failure 9");
|
||||||
|
Assert(tracker.RecvAll(&hrank, sizeof(hrank)) == sizeof(hrank),
|
||||||
|
"ReConnectLink failure 10");
|
||||||
|
|
||||||
|
r.sock.Create();
|
||||||
|
if (!r.sock.Connect(utils::SockAddr(hname.c_str(), hport))) {
|
||||||
|
num_error += 1;
|
||||||
|
r.sock.Close();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||||
|
"ReConnectLink failure 12");
|
||||||
|
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||||
|
"ReConnectLink failure 13");
|
||||||
|
utils::Check(hrank == r.rank,
|
||||||
|
"ReConnectLink failure, link rank inconsistent");
|
||||||
|
bool match = false;
|
||||||
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
|
if (all_links[i].rank == hrank) {
|
||||||
|
Assert(all_links[i].sock.IsClosed(),
|
||||||
|
"Override a link that is active");
|
||||||
|
all_links[i].sock = r.sock;
|
||||||
|
match = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!match) all_links.push_back(r);
|
||||||
|
}
|
||||||
|
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
|
||||||
|
"ReConnectLink failure 14");
|
||||||
|
} while (num_error != 0);
|
||||||
|
// send back socket listening port to tracker
|
||||||
|
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
|
||||||
|
"ReConnectLink failure 14");
|
||||||
|
// close connection to tracker
|
||||||
|
tracker.Close();
|
||||||
|
// listen to incoming links
|
||||||
|
for (int i = 0; i < num_accept; ++i) {
|
||||||
|
LinkRecord r;
|
||||||
|
r.sock = sock_listen.Accept();
|
||||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
||||||
"ReConnectLink failure 12");
|
"ReConnectLink failure 15");
|
||||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
||||||
"ReConnectLink failure 13");
|
"ReConnectLink failure 15");
|
||||||
utils::Check(hrank == r.rank,
|
|
||||||
"ReConnectLink failure, link rank inconsistent");
|
|
||||||
bool match = false;
|
bool match = false;
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
if (all_links[i].rank == hrank) {
|
if (all_links[i].rank == r.rank) {
|
||||||
Assert(all_links[i].sock.IsClosed(),
|
utils::Assert(all_links[i].sock.IsClosed(),
|
||||||
"Override a link that is active");
|
"Override a link that is active");
|
||||||
all_links[i].sock = r.sock; match = true; break;
|
all_links[i].sock = r.sock;
|
||||||
|
match = true;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!match) all_links.push_back(r);
|
if (!match) all_links.push_back(r);
|
||||||
}
|
}
|
||||||
Assert(tracker.SendAll(&num_error, sizeof(num_error)) == sizeof(num_error),
|
|
||||||
"ReConnectLink failure 14");
|
|
||||||
} while (num_error != 0);
|
|
||||||
// send back socket listening port to tracker
|
|
||||||
Assert(tracker.SendAll(&port, sizeof(port)) == sizeof(port),
|
|
||||||
"ReConnectLink failure 14");
|
|
||||||
// close connection to tracker
|
|
||||||
tracker.Close();
|
|
||||||
// listen to incoming links
|
|
||||||
for (int i = 0; i < num_accept; ++i) {
|
|
||||||
LinkRecord r;
|
|
||||||
r.sock = sock_listen.Accept();
|
|
||||||
Assert(r.sock.SendAll(&rank, sizeof(rank)) == sizeof(rank),
|
|
||||||
"ReConnectLink failure 15");
|
|
||||||
Assert(r.sock.RecvAll(&r.rank, sizeof(r.rank)) == sizeof(r.rank),
|
|
||||||
"ReConnectLink failure 15");
|
|
||||||
bool match = false;
|
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
|
||||||
if (all_links[i].rank == r.rank) {
|
|
||||||
utils::Assert(all_links[i].sock.IsClosed(),
|
|
||||||
"Override a link that is active");
|
|
||||||
all_links[i].sock = r.sock; match = true; break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!match) all_links.push_back(r);
|
|
||||||
}
|
|
||||||
|
|
||||||
this->parent_index = -1;
|
this->parent_index = -1;
|
||||||
// setup tree links and ring structure
|
// setup tree links and ring structure
|
||||||
tree_links.plinks.clear();
|
tree_links.plinks.clear();
|
||||||
for (size_t i = 0; i < all_links.size(); ++i) {
|
for (size_t i = 0; i < all_links.size(); ++i) {
|
||||||
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
utils::Assert(!all_links[i].sock.BadSocket(), "ReConnectLink: bad socket");
|
||||||
// set the socket to non-blocking mode, enable TCP keepalive
|
// set the socket to non-blocking mode, enable TCP keepalive
|
||||||
all_links[i].sock.SetNonBlock(true);
|
all_links[i].sock.SetNonBlock(true);
|
||||||
all_links[i].sock.SetKeepAlive(true);
|
all_links[i].sock.SetKeepAlive(true);
|
||||||
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
if (tree_neighbors.count(all_links[i].rank) != 0) {
|
||||||
if (all_links[i].rank == parent_rank) {
|
if (all_links[i].rank == parent_rank) {
|
||||||
parent_index = static_cast<int>(tree_links.plinks.size());
|
parent_index = static_cast<int>(tree_links.plinks.size());
|
||||||
|
}
|
||||||
|
tree_links.plinks.push_back(&all_links[i]);
|
||||||
}
|
}
|
||||||
tree_links.plinks.push_back(&all_links[i]);
|
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
|
||||||
|
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
|
||||||
}
|
}
|
||||||
if (all_links[i].rank == prev_rank) ring_prev = &all_links[i];
|
Assert(parent_rank == -1 || parent_index != -1,
|
||||||
if (all_links[i].rank == next_rank) ring_next = &all_links[i];
|
"cannot find parent in the link");
|
||||||
|
Assert(prev_rank == -1 || ring_prev != NULL,
|
||||||
|
"cannot find prev ring in the link");
|
||||||
|
Assert(next_rank == -1 || ring_next != NULL,
|
||||||
|
"cannot find next ring in the link");
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
fprintf(stderr, "failed in ReconnectLink %s\n", e.what());
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
Assert(parent_rank == -1 || parent_index != -1,
|
|
||||||
"cannot find parent in the link");
|
|
||||||
Assert(prev_rank == -1 || ring_prev != NULL,
|
|
||||||
"cannot find prev ring in the link");
|
|
||||||
Assert(next_rank == -1 || ring_next != NULL,
|
|
||||||
"cannot find next ring in the link");
|
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||||
|
|||||||
@ -38,9 +38,9 @@ class AllreduceBase : public IEngine {
|
|||||||
AllreduceBase(void);
|
AllreduceBase(void);
|
||||||
virtual ~AllreduceBase(void) {}
|
virtual ~AllreduceBase(void) {}
|
||||||
// initialize the manager
|
// initialize the manager
|
||||||
virtual void Init(int argc, char* argv[]);
|
virtual bool Init(int argc, char* argv[]);
|
||||||
// shutdown the engine
|
// shutdown the engine
|
||||||
virtual void Shutdown(void);
|
virtual bool Shutdown(void);
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
@ -369,7 +369,7 @@ class AllreduceBase : public IEngine {
|
|||||||
* this function is also used when the engine start up
|
* this function is also used when the engine start up
|
||||||
* \param cmd possible command to sent to tracker
|
* \param cmd possible command to sent to tracker
|
||||||
*/
|
*/
|
||||||
void ReConnectLinks(const char *cmd = "start");
|
bool ReConnectLinks(const char *cmd = "start");
|
||||||
/*!
|
/*!
|
||||||
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
|
||||||
*
|
*
|
||||||
|
|||||||
@ -31,31 +31,41 @@ AllreduceRobust::AllreduceRobust(void) {
|
|||||||
env_vars.push_back("rabit_global_replica");
|
env_vars.push_back("rabit_global_replica");
|
||||||
env_vars.push_back("rabit_local_replica");
|
env_vars.push_back("rabit_local_replica");
|
||||||
}
|
}
|
||||||
void AllreduceRobust::Init(int argc, char* argv[]) {
|
bool AllreduceRobust::Init(int argc, char* argv[]) {
|
||||||
AllreduceBase::Init(argc, argv);
|
if (AllreduceBase::Init(argc, argv)) {
|
||||||
if (num_global_replica == 0) {
|
if (num_global_replica == 0) {
|
||||||
result_buffer_round = -1;
|
result_buffer_round = -1;
|
||||||
|
} else {
|
||||||
|
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
} else {
|
} else {
|
||||||
result_buffer_round = std::max(world_size / num_global_replica, 1);
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
void AllreduceRobust::Shutdown(void) {
|
bool AllreduceRobust::Shutdown(void) {
|
||||||
// need to sync the exec before we shutdown, do a pesudo check point
|
try {
|
||||||
// execute checkpoint, note: when checkpoint existing, load will not happen
|
// need to sync the exec before we shutdown, do a pesudo check point
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
// execute checkpoint, note: when checkpoint existing, load will not happen
|
||||||
"Shutdown: check point must return true");
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint, ActionSummary::kSpecialOp),
|
||||||
// reset result buffer
|
"Shutdown: check point must return true");
|
||||||
resbuf.Clear(); seq_counter = 0;
|
// reset result buffer
|
||||||
// execute check ack step, load happens here
|
resbuf.Clear();
|
||||||
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
seq_counter = 0;
|
||||||
"Shutdown: check ack must return true");
|
// execute check ack step, load happens here
|
||||||
|
utils::Assert(RecoverExec(NULL, 0, ActionSummary::kCheckAck, ActionSummary::kSpecialOp),
|
||||||
|
"Shutdown: check ack must return true");
|
||||||
|
|
||||||
#if defined (__APPLE__)
|
#if defined (__APPLE__)
|
||||||
sleep(1);
|
sleep(1);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
AllreduceBase::Shutdown();
|
return AllreduceBase::Shutdown();
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
fprintf(stderr, "%s\n", e.what());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
|
|||||||
@ -24,9 +24,9 @@ class AllreduceRobust : public AllreduceBase {
|
|||||||
AllreduceRobust(void);
|
AllreduceRobust(void);
|
||||||
virtual ~AllreduceRobust(void) {}
|
virtual ~AllreduceRobust(void) {}
|
||||||
// initialize the manager
|
// initialize the manager
|
||||||
virtual void Init(int argc, char* argv[]);
|
virtual bool Init(int argc, char* argv[]);
|
||||||
/*! \brief shutdown the engine */
|
/*! \brief shutdown the engine */
|
||||||
virtual void Shutdown(void);
|
virtual bool Shutdown(void);
|
||||||
/*!
|
/*!
|
||||||
* \brief set parameters to the engine
|
* \brief set parameters to the engine
|
||||||
* \param name parameter name
|
* \param name parameter name
|
||||||
|
|||||||
@ -162,12 +162,12 @@ struct WriteWrapper : public Serializable {
|
|||||||
} // namespace c_api
|
} // namespace c_api
|
||||||
} // namespace rabit
|
} // namespace rabit
|
||||||
|
|
||||||
void RabitInit(int argc, char *argv[]) {
|
bool RabitInit(int argc, char *argv[]) {
|
||||||
rabit::Init(argc, argv);
|
return rabit::Init(argc, argv);
|
||||||
}
|
}
|
||||||
|
|
||||||
void RabitFinalize() {
|
bool RabitFinalize() {
|
||||||
rabit::Finalize();
|
return rabit::Finalize();
|
||||||
}
|
}
|
||||||
|
|
||||||
int RabitGetRank() {
|
int RabitGetRank() {
|
||||||
|
|||||||
@ -43,23 +43,29 @@ struct ThreadLocalEntry {
|
|||||||
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
|
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;
|
||||||
|
|
||||||
/*! \brief initialize the synchronization module */
|
/*! \brief initialize the synchronization module */
|
||||||
void Init(int argc, char *argv[]) {
|
bool Init(int argc, char *argv[]) {
|
||||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||||
if (e->engine.get() == nullptr) {
|
if (e->engine.get() == nullptr) {
|
||||||
e->initialized = true;
|
e->initialized = true;
|
||||||
e->engine.reset(new Manager());
|
e->engine.reset(new Manager());
|
||||||
e->engine->Init(argc, argv);
|
return e->engine->Init(argc, argv);
|
||||||
|
} else {
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief finalize synchronization module */
|
/*! \brief finalize synchronization module */
|
||||||
void Finalize() {
|
bool Finalize() {
|
||||||
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
ThreadLocalEntry* e = EngineThreadLocal::Get();
|
||||||
utils::Check(e->engine.get() != nullptr,
|
utils::Check(e->engine.get() != nullptr,
|
||||||
"rabit::Finalize engine is not initialized or already been finalized.");
|
"rabit::Finalize engine is not initialized or already been finalized.");
|
||||||
e->engine->Shutdown();
|
if (e->engine->Shutdown()) {
|
||||||
e->engine.reset(nullptr);
|
e->engine.reset(nullptr);
|
||||||
e->initialized = false;
|
e->initialized = false;
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief singleton method to get engine */
|
/*! \brief singleton method to get engine */
|
||||||
|
|||||||
@ -82,10 +82,12 @@ class EmptyEngine : public IEngine {
|
|||||||
EmptyEngine manager;
|
EmptyEngine manager;
|
||||||
|
|
||||||
/*! \brief initialize the synchronization module */
|
/*! \brief initialize the synchronization module */
|
||||||
void Init(int argc, char *argv[]) {
|
bool Init(int argc, char *argv[]) {
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
/*! \brief finalize synchronization module */
|
/*! \brief finalize synchronization module */
|
||||||
void Finalize(void) {
|
bool Finalize(void) {
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief singleton method to get engine */
|
/*! \brief singleton method to get engine */
|
||||||
|
|||||||
@ -90,13 +90,25 @@ class MPIEngine : public IEngine {
|
|||||||
// singleton sync manager
|
// singleton sync manager
|
||||||
MPIEngine manager;
|
MPIEngine manager;
|
||||||
|
|
||||||
/*! \brief intiialize the synchronization module */
|
/*! \brief initialize the synchronization module */
|
||||||
void Init(int argc, char *argv[]) {
|
bool Init(int argc, char *argv[]) {
|
||||||
MPI::Init(argc, argv);
|
try {
|
||||||
|
MPI::Init(argc, argv);
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
fprintf(stderr, " failed in MPI Init %s\n", e.what());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
/*! \brief finalize synchronization module */
|
/*! \brief finalize synchronization module */
|
||||||
void Finalize(void) {
|
bool Finalize(void) {
|
||||||
MPI::Finalize();
|
try {
|
||||||
|
MPI::Finalize();
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception& e) {
|
||||||
|
fprintf(stderr, "failed in MPI shutdown %s\n", e.what());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief singleton method to get engine */
|
/*! \brief singleton method to get engine */
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user